def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
    """Validate the previous result and extract its retrieved ids.

    Args:
        previous_result: DataFrame produced by the preceding step. It is
            checked with ``validate_qa_dataset`` and must contain a
            ``retrieved_ids`` column.
        *args: Accepted but not used by this method.
        **kwargs: Accepted but not used by this method.

    Returns:
        The ``retrieved_ids`` column converted to a plain list
        (one entry per row of ``previous_result``).

    Raises:
        AssertionError: If ``previous_result`` has no ``retrieved_ids``
            column.
    """
    logger.info(
        f"Running passage augmenter node - {self.__class__.__name__} module..."
    )
    validate_qa_dataset(previous_result)

    # Find the ids column.
    # NOTE(review): `assert` is stripped when Python runs with -O, so this
    # check disappears in optimized mode; kept as-is so callers still see
    # AssertionError.
    assert (
        "retrieved_ids" in previous_result.columns
    ), "previous_result must have retrieved_ids column."
    ids = previous_result["retrieved_ids"].tolist()

    return ids
@staticmethod
def sort_by_scores(
    augmented_contents,
    augmented_ids,
    augmented_scores,
    top_k: int,
    reverse: bool = True,
):
    """Sort each row's contents/ids/scores and keep the top ``top_k``.

    The three inputs are placed side by side in a DataFrame (columns
    ``contents``, ``ids``, ``scores``), so they must have the same length.
    Each row is passed to the ``sort_by_scores`` helper, and the frame is
    then trimmed with ``select_top_k``.

    Args:
        augmented_contents: Per-row contents; becomes the ``contents`` column.
        augmented_ids: Per-row ids; becomes the ``ids`` column.
        augmented_scores: Per-row scores; becomes the ``scores`` column.
        top_k: Forwarded to ``select_top_k``.
        reverse: Forwarded to the ``sort_by_scores`` helper.
            (Presumably ``True`` means highest score first; confirm
            against the helper.)

    Returns:
        A 3-tuple of lists ``(contents, ids, scores)`` taken from the
        processed frame.
    """
    # Sort by scores.
    df = pd.DataFrame(
        {
            "contents": augmented_contents,
            "ids": augmented_ids,
            "scores": augmented_scores,
        }
    )
    # NOTE: inside this body the bare name `sort_by_scores` is looked up in
    # module scope (class scope is skipped for names used inside methods), so
    # this calls the helper of the same name, presumably imported at the top
    # of the file, and does not recurse into this static method.
    # result_type="expand" spreads the helper's returned sequence across the
    # three target columns.
    df[["contents", "ids", "scores"]] = df.apply(
        lambda row: sort_by_scores(row, reverse=reverse),
        axis=1,
        result_type="expand",
    )

    # Select by top_k.
    results = select_top_k(df, ["contents", "ids", "scores"], top_k)

    return (
        results["contents"].tolist(),
        results["ids"].tolist(),
        results["scores"].tolist(),
    )