Source code for autorag.nodes.passagecompressor.base

import functools
import logging
from pathlib import Path
from typing import List, Union, Dict

import pandas as pd
from llama_index.core.llms import LLM

from autorag import generator_models
from autorag.utils import result_to_dataframe

logger = logging.getLogger("AutoRAG")


def passage_compressor_node(func):
    @functools.wraps(func)
    @result_to_dataframe(["retrieved_contents"])
    def wrapper(
        project_dir: Union[str, Path], previous_result: pd.DataFrame, *args, **kwargs
    ) -> List[List[str]]:
        logger.info(f"Running passage compressor node - {func.__name__} module...")
        assert all(
            [
                column in previous_result.columns
                for column in [
                    "query",
                    "retrieved_contents",
                    "retrieved_ids",
                    "retrieve_scores",
                ]
            ]
        ), "previous_result must have query, retrieved_contents, retrieved_ids, and retrieve_scores columns."
        assert len(previous_result) > 0, "previous_result must have at least one row."

        queries = previous_result["query"].tolist()
        retrieved_contents = previous_result["retrieved_contents"].tolist()
        retrieved_ids = previous_result["retrieved_ids"].tolist()
        retrieve_scores = previous_result["retrieve_scores"].tolist()

        if func.__name__ in ["tree_summarize", "refine"]:
            # Separate prompt/response parameters from the remaining kwargs,
            # which are treated as LLM constructor arguments.
            param_list = [
                "prompt",
                "chat_prompt",
                "context_window",
                "num_output",
                "batch",
            ]
            param_dict = dict(filter(lambda x: x[0] in param_list, kwargs.items()))
            kwargs_dict = dict(filter(lambda x: x[0] not in param_list, kwargs.items()))
            llm_name = kwargs_dict.pop("llm")
            llm = make_llm(llm_name, kwargs_dict)
            result = func(
                queries=queries,
                contents=retrieved_contents,
                scores=retrieve_scores,
                ids=retrieved_ids,
                llm=llm,
                **param_dict,
            )
            del llm
            # Wrap each compressed passage string in a single-element list
            # so it fits the retrieved_contents column format.
            result = list(map(lambda x: [x], result))
        elif func.__name__ == "longllmlingua":
            result = func(
                queries=queries,
                contents=retrieved_contents,
                scores=retrieve_scores,
                ids=retrieved_ids,
                **kwargs,
            )
            result = list(map(lambda x: [x], result))
        elif func.__name__ == "pass_compressor":
            result = func(contents=retrieved_contents)
        else:
            raise ValueError(
                f"{func.__name__} is not supported in passage compressor node."
            )
        return result

    return wrapper
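
# A minimal usage sketch (hypothetical, not part of this module): the decorator
# turns a compressor implementation into a node that consumes the previous
# node's result DataFrame. Dispatch is keyed on func.__name__, so the wrapped
# function must be named "tree_summarize", "refine", "longllmlingua", or
# "pass_compressor". The stand-in below mirrors the pass-through case only.
#
#     @passage_compressor_node
#     def pass_compressor(contents: List[List[str]]) -> List[List[str]]:
#         return contents  # no compression; pass retrieved contents through
#
#     compressed_df = pass_compressor(
#         project_dir="./project",               # hypothetical path
#         previous_result=retrieval_result_df,   # must have query, retrieved_contents,
#     )                                          # retrieved_ids, retrieve_scores columns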


def make_llm(llm_name: str, kwargs: Dict) -> LLM:
    if llm_name not in generator_models:
        raise KeyError(
            f"{llm_name} is not supported. "
            "You can add it manually to autorag.generator_models."
        )
    return generator_models[llm_name](**kwargs)
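
# A minimal sketch (hypothetical values) of how make_llm resolves an LLM:
# llm_name must be a key registered in autorag.generator_models, and the
# remaining kwargs are forwarded to that LLM class's constructor.
#
#     # "openai" is assumed to be a registered key; the constructor kwargs
#     # below are illustrative, not guaranteed defaults.
#     llm = make_llm("openai", {"model": "gpt-4o-mini", "temperature": 0.0})
#
#     # Unknown names raise KeyError; register a custom LLM class first:
#     # generator_models["my_llm"] = MyCustomLLM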