class FlashRankReranker(BasePassageReranker):
    def __init__(self, project_dir: str, model: str = "ms-marco-TinyBERT-L-2-v2", *args, **kwargs):
        """
        Initialize FlashRank rerank node.

        :param project_dir: The project directory path.
        :param model: The model name for FlashRank rerank.
            You can get the list of available models from
            https://github.com/PrithivirajDamodaran/FlashRank.
            Default is "ms-marco-TinyBERT-L-2-v2".
            Not support "rank_zephyr_7b_v1_full" due to parallel inference issue.
        :param kwargs: Extra arguments that are not affected
        """
        super().__init__(project_dir)
        try:
            from tokenizers import Tokenizer
        except ImportError:
            raise ImportError("Tokenizer is not installed. Please install tokenizers to use FlashRank reranker.")

        # Optional knobs pulled out of **kwargs; everything else is ignored.
        cache_dir = kwargs.pop("cache_dir", "/tmp")
        max_length = kwargs.pop("max_length", 512)

        self.cache_dir: Path = Path(cache_dir)
        self.model_dir: Path = self.cache_dir / model
        # Downloads and extracts the model archive on first use.
        self._prepare_model_dir(model)
        model_file = model_file_map[model]

        try:
            import onnxruntime as ort
        except ImportError:
            raise ImportError("onnxruntime is not installed. Please install onnxruntime to use FlashRank reranker.")
        self.session = ort.InferenceSession(str(self.model_dir / model_file))
        self.tokenizer: Tokenizer = self._get_tokenizer(max_length)

    def __del__(self):
        """Release the ONNX session and tokenizer, then free any cached CUDA memory."""
        # Guard with hasattr: if __init__ raised partway (e.g. onnxruntime missing),
        # these attributes were never set and a bare `del` would raise AttributeError
        # during interpreter teardown.
        if hasattr(self, "session"):
            del self.session
        if hasattr(self, "tokenizer"):
            del self.tokenizer
        empty_cuda_cache()
        super().__del__()

    def _prepare_model_dir(self, model_name: str):
        """Ensure the cache directory exists and the model files are downloaded."""
        if not self.cache_dir.exists():
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        if not self.model_dir.exists():
            self._download_model_files(model_name)

    def _download_model_files(self, model_name: str):
        """Download the zipped model archive, extract it into the cache dir, and delete the zip.

        :param model_name: The FlashRank model name; substituted into ``model_url``.
        """
        local_zip_file = self.cache_dir / f"{model_name}.zip"
        formatted_model_url = model_url.format(model_name)
        # Stream the download in chunks with a progress bar so large models
        # are never held fully in memory.
        with requests.get(formatted_model_url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with (
                open(local_zip_file, "wb") as f,
                tqdm(
                    desc=local_zip_file.name,
                    total=total_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar,
            ):
                for chunk in r.iter_content(chunk_size=8192):
                    size = f.write(chunk)
                    bar.update(size)
        with zipfile.ZipFile(local_zip_file, "r") as zip_ref:
            zip_ref.extractall(self.cache_dir)
        os.remove(local_zip_file)

    def _get_tokenizer(self, max_length: int = 512):
        """Build the HuggingFace ``tokenizers.Tokenizer`` from the downloaded model files.

        :param max_length: Upper bound for truncation; clamped by the model's own
            ``model_max_length`` from its tokenizer config.
        :return: A configured ``Tokenizer`` with truncation, padding, and special tokens.
        """
        try:
            from tokenizers import AddedToken, Tokenizer
        except ImportError:
            # Fixed: the original message wrongly blamed "Pytorch" although the
            # failed import is the `tokenizers` package.
            raise ImportError("tokenizers is not installed. Please install tokenizers to use FlashRank reranker.")

        # Use context managers so the config file handles are closed promptly
        # (the original json.load(open(...)) calls leaked them).
        with open(str(self.model_dir / "config.json")) as f:
            config = json.load(f)
        with open(str(self.model_dir / "tokenizer_config.json")) as f:
            tokenizer_config = json.load(f)
        with open(str(self.model_dir / "special_tokens_map.json")) as f:
            tokens_map = json.load(f)

        tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))
        tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
        tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"])

        # Special tokens may be plain strings or dicts of AddedToken kwargs.
        for token in tokens_map.values():
            if isinstance(token, str):
                tokenizer.add_special_tokens([token])
            elif isinstance(token, dict):
                tokenizer.add_special_tokens([AddedToken(**token)])

        vocab_file = self.model_dir / "vocab.txt"
        if vocab_file.exists():
            tokenizer.vocab = self._load_vocab(vocab_file)
            tokenizer.ids_to_tokens = collections.OrderedDict(
                [(ids, tok) for tok, ids in tokenizer.vocab.items()]
            )
        return tokenizer

    def _load_vocab(self, vocab_file: Path) -> Dict[str, int]:
        """Load a WordPiece-style vocab file into an ordered token -> id mapping.

        :param vocab_file: Path to a vocab.txt with one token per line; line
            number (0-based) is the token id.
        :return: OrderedDict mapping token string to its integer id.
        """
        vocab = collections.OrderedDict()
        with open(vocab_file, "r", encoding="utf-8") as reader:
            tokens = reader.readlines()
        for index, token in enumerate(tokens):
            token = token.rstrip("\n")
            vocab[token] = index
        return vocab

    def _pure(
        self,
        queries: List[str],
        contents_list: List[List[str]],
        ids_list: List[List[str]],
        top_k: int,
        batch: int = 64,
    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        """
        Rerank a list of contents with FlashRank rerank models.

        :param queries: The list of queries to use for reranking
        :param contents_list: The list of lists of contents to rerank
        :param ids_list: The list of lists of ids retrieved from the initial ranking
        :param top_k: The number of passages to be retrieved
        :param batch: The number of queries to be processed in a batch
        :return: Tuple of lists containing the reranked contents, ids, and scores
        """
        # Pair each query with each of its candidate passages: [[query, passage], ...]
        nested_list = [
            list(map(lambda x: [query, x], content_list))
            for query, content_list in zip(queries, contents_list)
        ]
        rerank_scores = flatten_apply(
            flashrank_run_model,
            nested_list,
            session=self.session,
            batch_size=batch,
            tokenizer=self.tokenizer,
        )
        df = pd.DataFrame(
            {
                "contents": contents_list,
                "ids": ids_list,
                "scores": rerank_scores,
            }
        )
        # Sort each row's passages by score, then keep the top_k per query.
        df[["contents", "ids", "scores"]] = df.apply(sort_by_scores, axis=1, result_type="expand")
        results = select_top_k(df, ["contents", "ids", "scores"], top_k)
        return (
            results["contents"].tolist(),
            results["ids"].tolist(),
            results["scores"].tolist(),
        )