"""autorag.nodes.promptmaker.long_context_reorder — Long Context Reorder prompt maker."""

import logging
from typing import List

import numpy as np
import pandas as pd

from autorag.nodes.promptmaker.base import BasePromptMaker
from autorag.utils import result_to_dataframe

# Project-wide logger shared under the "AutoRAG" name; handler/level
# configuration is presumably done elsewhere in the package — not visible here.
logger = logging.getLogger("AutoRAG")


class LongContextReorder(BasePromptMaker):
    """Prompt maker that re-exposes the most relevant passage at the end of
    the context. Per https://arxiv.org/abs/2307.03172, models attend best to
    information at the start or end of a long input, so the highest-scored
    passage is appended after the retrieval-ordered contents."""

    @result_to_dataframe(["prompts"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Build prompts from a previous node's result DataFrame.

        :param previous_result: DataFrame produced by an upstream node. Must
            contain the columns consumed by ``cast_to_run`` plus a
            ``retrieve_scores`` column (one score list per row).
        :return: A one-column ``prompts`` DataFrame (via the decorator).
        :raises AssertionError: If ``retrieve_scores`` is missing.
        """
        query, retrieved_contents, prompt = self.cast_to_run(
            previous_result, *args, **kwargs
        )
        assert (
            "retrieve_scores" in previous_result.columns
        ), "previous_result must have retrieve_scores column."
        retrieve_scores = previous_result["retrieve_scores"].tolist()
        return self._pure(prompt, query, retrieved_contents, retrieve_scores)

    def _pure(
        self,
        prompt: str,
        queries: List[str],
        retrieved_contents: List[List[str]],
        retrieve_scores: List[List[float]],
    ) -> List[str]:
        """
        Models struggle to access significant details found in the center
        of extended contexts. A study (https://arxiv.org/abs/2307.03172)
        observed that the best performance typically arises when crucial data
        is positioned at the start or conclusion of the input context.
        Additionally, as the input context lengthens, performance drops
        notably, even in models designed for long contexts.

        .. Code:: yaml

            nodes:
            - node_type: prompt_maker
              modules:
              - module_type: long_context_reorder
                prompt: [Answer this question: {query} \n\n {retrieved_contents},
                         Read the passages carefully and answer this question:
                         {query} \n\n Passages: {retrieved_contents}]

        :param prompt: A prompt string with ``{query}`` and
            ``{retrieved_contents}`` placeholders.
        :param queries: List of query strings.
        :param retrieved_contents: List of retrieved contents (one list of
            passages per query).
        :param retrieve_scores: List of retrieve scores, parallel to
            ``retrieved_contents``.
        :return: Prompts that are made by long context reorder.
        """

        def long_context_reorder_row(
            _prompt: str,
            _query: str,
            _retrieved_contents: List[str],
            _retrieve_scores: List[float],
        ) -> str:
            # Rows coming out of a DataFrame may hold numpy arrays.
            if isinstance(_retrieved_contents, np.ndarray):
                _retrieved_contents = _retrieved_contents.tolist()
            if len(_retrieved_contents) != len(_retrieve_scores):
                # A summarizer can collapse the contents, breaking the 1:1
                # pairing with scores; fall back to the unmodified order.
                logger.info("If you use a summarizer, the reorder will not proceed.")
                return _prompt.format(
                    query=_query, retrieved_contents="\n\n".join(_retrieved_contents)
                )
            # Only the single best-scored passage is needed, so use max()
            # (O(n)) instead of a full sort; like the original stable sort
            # with reverse=True, max() returns the first item on score ties.
            best_content = max(
                zip(_retrieved_contents, _retrieve_scores), key=lambda pair: pair[1]
            )[0]
            # Build a fresh list rather than appending into the caller's
            # list (the original mutated the DataFrame-owned list in place).
            reordered = list(_retrieved_contents) + [best_content]
            return _prompt.format(
                query=_query, retrieved_contents="\n\n".join(reordered)
            )

        return [
            long_context_reorder_row(prompt, query, contents, scores)
            for query, contents, scores in zip(
                queries, retrieved_contents, retrieve_scores
            )
        ]