Source code for autorag.nodes.passagefilter.threshold_cutoff

from typing import List, Tuple

import pandas as pd

from autorag.nodes.passagefilter.base import BasePassageFilter
from autorag.utils.util import convert_inputs_to_list, result_to_dataframe


class ThresholdCutoff(BasePassageFilter):
    """Passage filter that keeps only the passages whose score passes a threshold."""

    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Filter the passages found in ``previous_result`` by score threshold.

        The contents, scores and ids are extracted with ``self.cast_to_run`` and
        handed to :meth:`_pure`; extra positional and keyword arguments
        (``threshold``, ``reverse``) are forwarded unchanged.

        :param previous_result: The result of the previous node.
        :return: Filtered contents, ids and scores, in that order.
            NOTE(review): presumably ``result_to_dataframe`` turns this tuple into
            a DataFrame with the listed column names; confirm against its definition.
        """
        _, contents, scores, ids = self.cast_to_run(previous_result)
        return self._pure(contents, scores, ids, *args, **kwargs)

    def _pure(
        self,
        contents_list: List[List[str]],
        scores_list: List[List[float]],
        ids_list: List[List[str]],
        threshold: float,
        reverse: bool = False,
    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        """
        Filters the contents, scores, and ids based on a previous result's score.
        Keeps at least one item per query if all scores are below the threshold.
        A query that has no items at all yields empty lists.

        :param contents_list: List of content strings for each query.
        :param scores_list: List of scores for each content.
        :param ids_list: List of ids for each content.
        :param threshold: The minimum score to keep an item.
        :param reverse: If True, the lower the score, the better.
            Default is False.
        :return: Filtered lists of contents, ids, and scores.
        """
        # One list of surviving positions per query.
        remain_indices = [
            self.__row_pure(scores, threshold, reverse) for scores in scores_list
        ]

        # Apply the same positions to contents, scores and ids so that the three
        # outputs stay aligned with each other.
        remain_content_list = [
            [contents[i] for i in indices]
            for contents, indices in zip(contents_list, remain_indices)
        ]
        remain_scores_list = [
            [scores[i] for i in indices]
            for scores, indices in zip(scores_list, remain_indices)
        ]
        remain_ids_list = [
            [ids[i] for i in indices]
            for ids, indices in zip(ids_list, remain_indices)
        ]
        return remain_content_list, remain_ids_list, remain_scores_list

    @convert_inputs_to_list
    def __row_pure(
        self, scores_list: List[float], threshold: float, reverse: bool = False
    ) -> List[int]:
        """
        Return indices that have to remain.
        Return at least one index if there is nothing to remain,
        unless there are no scores at all, in which case the result is empty.

        :param scores_list: Each score
        :param threshold: The threshold to cut off
        :param reverse: If True, the lower the score, the better
            Default is False.
        :return: Indices to remain at the contents
        """
        assert isinstance(scores_list, list), "scores_list must be a list."

        # Nothing was retrieved for this query: there is no best item to fall
        # back to, and min()/max() would raise ValueError on an empty list.
        if not scores_list:
            return []

        if reverse:
            remain_indices = [
                i for i, score in enumerate(scores_list) if score <= threshold
            ]
        else:
            remain_indices = [
                i for i, score in enumerate(scores_list) if score >= threshold
            ]
        if remain_indices:
            return remain_indices

        # No score passed the threshold: keep the single best item instead
        # (the lowest score when reverse is True, the highest otherwise).
        best_score = min(scores_list) if reverse else max(scores_list)
        return [scores_list.index(best_score)]