Source code for autorag.nodes.retrieval.base

import functools
import logging
import os
import pickle
from pathlib import Path
from typing import List, Union, Tuple, Dict

import chromadb
import pandas as pd
import torch

from autorag import embedding_models
from autorag.support import get_support_modules
from autorag.utils import fetch_contents, result_to_dataframe, validate_qa_dataset

logger = logging.getLogger("AutoRAG")

def retrieval_node(func):
    """
    Load resources for running retrieval_node.
    For example, it loads bm25 corpus for bm25 retrieval.

    The decorated function is dispatched on ``func.__name__``; supported names
    are ``"bm25"``, ``"vectordb"``, ``"hybrid_rrf"`` and ``"hybrid_cc"``.

    :param func: Retrieval function that returns a list of ids and a list of scores
    :return: A pandas Dataframe that contains retrieved contents, retrieved ids, and retrieve scores.
        The column name will be "retrieved_contents", "retrieved_ids", and "retrieve_scores".
    """

    @functools.wraps(func)
    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def wrapper(
        project_dir: Union[str, Path], previous_result: pd.DataFrame, **kwargs
    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        """
        Resolve the resources ``func`` needs, run it, and fetch the contents.

        :param project_dir: Project root; ``resources`` and ``data`` are looked
            up directly beneath it.
        :param previous_result: Dataframe with a ``query`` column and optionally
            a ``queries`` column. A ``queries`` column is added to it in place
            when missing, and its values are normalized with ``cast_queries``.
        :param kwargs: Module parameters forwarded to ``func``.
            ``embedding_model`` (vectordb) and ``target_modules`` /
            ``target_module_params`` (hybrid) are consumed here and not forwarded.
        :return: Tuple of (contents, ids, scores).
        :raises AssertionError: If the bm25 pickle or the chroma directory is
            missing, or ``previous_result`` has no ``query`` column.
        :raises KeyError: If the ``embedding_model`` kwarg is absent or is not a
            key of ``embedding_models``.
        :raises ValueError: If ``func.__name__`` is not supported, or a hybrid
            function gets neither ids/scores nor target modules.
        """
        validate_qa_dataset(previous_result)
        resources_dir = os.path.join(project_dir, "resources")
        data_dir = os.path.join(project_dir, "data")

        # Check that the ingested resources exist before doing any work.
        # The checks stay as asserts so callers keep seeing AssertionError.
        if func.__name__ == "bm25":
            bm25_tokenizer = kwargs.get("bm25_tokenizer", None)
            if bm25_tokenizer is None:
                bm25_tokenizer = "porter_stemmer"
            bm25_path = os.path.join(resources_dir, get_bm25_pkl_name(bm25_tokenizer))
            assert os.path.exists(
                bm25_path
            ), f"bm25_path {bm25_path} does not exist. Please ingest first."
        elif func.__name__ == "vectordb":
            chroma_path = os.path.join(resources_dir, "chroma")
            # Popped so that it is not forwarded to ``func`` through **kwargs.
            embedding_model_str = kwargs.pop("embedding_model")
            assert os.path.exists(
                chroma_path
            ), f"chroma_path {chroma_path} does not exist. Please ingest first."

        # find queries columns & type cast queries
        assert (
            "query" in previous_result.columns
        ), "previous_result must have query column."
        if "queries" not in previous_result.columns:
            previous_result["queries"] = previous_result["query"]
        previous_result.loc[:, "queries"] = previous_result["queries"].apply(
            cast_queries
        )
        queries = previous_result["queries"].tolist()

        # run retrieval function
        if func.__name__ == "bm25":
            bm25_corpus = load_bm25_corpus(bm25_path)
            ids, scores = func(queries=queries, bm25_corpus=bm25_corpus, **kwargs)
        elif func.__name__ == "vectordb":
            chroma_collection = load_chroma_collection(
                db_path=chroma_path, collection_name=embedding_model_str
            )
            if embedding_model_str in embedding_models:
                embedding_model = embedding_models[embedding_model_str]
            else:
                logger.error(
                    f"embedding_model_str {embedding_model_str} does not exist."
                )
                raise KeyError(
                    f"embedding_model_str {embedding_model_str} does not exist."
                )
            try:
                ids, scores = func(
                    queries=queries,
                    collection=chroma_collection,
                    embedding_model=embedding_model,
                    **kwargs,
                )
            finally:
                # Drop the local reference and empty the CUDA cache even when
                # the retrieval call raises.
                del embedding_model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        elif func.__name__ in ["hybrid_rrf", "hybrid_cc"]:
            if "ids" in kwargs and "scores" in kwargs:  # ordinary run_evaluate
                ids, scores = func(**kwargs)
            else:
                # No precomputed ids/scores: run every target module first and
                # fuse their results.
                if not (
                    "target_modules" in kwargs and "target_module_params" in kwargs
                ):
                    raise ValueError(
                        f"If there are no ids and scores specified, target_modules and target_module_params must be specified for using {func.__name__}."
                    )
                target_modules = kwargs.pop("target_modules")
                target_module_params = kwargs.pop("target_module_params")
                result_dfs = [
                    get_support_modules(module_name)(
                        **module_params,
                        project_dir=project_dir,
                        previous_result=previous_result,
                    )
                    for module_name, module_params in zip(
                        target_modules, target_module_params
                    )
                ]
                ids = tuple(
                    df["retrieved_ids"].apply(list).tolist() for df in result_dfs
                )
                scores = tuple(
                    df["retrieve_scores"].apply(list).tolist() for df in result_dfs
                )
                ids, scores = func(ids=ids, scores=scores, **kwargs)
        else:
            raise ValueError("invalid func name for using retrieval_io decorator.")

        # fetch data from corpus_data
        corpus_data = pd.read_parquet(
            os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
        )
        contents = fetch_contents(corpus_data, ids)

        return contents, ids, scores

    return wrapper
def load_bm25_corpus(bm25_path: str) -> Dict:
    """
    Read a pickled bm25 corpus from disk.

    :param bm25_path: Path of the pickle file. ``None`` is accepted and yields
        an empty dict.
    :return: The unpickled object, or ``{}`` when ``bm25_path`` is ``None``.
    """
    if bm25_path is None:
        return {}
    # NOTE: pickle can execute arbitrary code while loading; only point this at
    # files this project produced itself.
    with open(bm25_path, "rb") as pkl_file:
        return pickle.load(pkl_file)
def load_chroma_collection(db_path: str, collection_name: str) -> chromadb.Collection:
    """
    Open a persistent chroma client and look up one of its collections.

    :param db_path: Path passed to ``chromadb.PersistentClient``.
    :param collection_name: Name passed to ``get_collection``.
    :return: The collection returned by ``get_collection``.
    """
    client = chromadb.PersistentClient(path=db_path)
    return client.get_collection(name=collection_name)
def cast_queries(queries: Union[str, List[str]]) -> List[str]:
    """
    Normalize a query cell to a list of query strings.

    :param queries: A single query string or a list of query strings.
    :return: ``[queries]`` for a string; the same list object for a list.
    :raises ValueError: If ``queries`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(queries, str):
        return [queries]
    if isinstance(queries, list):
        return queries
    raise ValueError(f"queries must be str or list, but got {type(queries)}")
def evenly_distribute_passages(
    ids: List[List[str]], scores: List[List[float]], top_k: int
) -> Tuple[List[str], List[float]]:
    """
    Merge per-query results into one flat list, splitting ``top_k`` evenly.

    Each query contributes ``top_k // len(ids)`` leading passages; the first
    ``top_k % len(ids)`` queries contribute one extra. A query with fewer
    passages than its share contributes what it has, so the result can be
    shorter than ``top_k``.

    :param ids: One list of passage ids per query.
    :param scores: One list of scores per query, parallel to ``ids``.
    :param top_k: Total number of passages to distribute over the queries.
    :return: Tuple of (flattened ids, flattened scores). Both are empty when
        ``ids`` is empty.
    :raises AssertionError: If ``ids`` and ``scores`` differ in length.
    """
    # Kept as an assert so existing callers still see AssertionError.
    assert len(ids) == len(scores), "ids and scores must have same length."
    query_cnt = len(ids)
    if query_cnt == 0:
        # Nothing to distribute; also avoids dividing by zero below.
        return [], []

    avg_len, remainder = divmod(top_k, query_cnt)
    new_ids: List[str] = []
    new_scores: List[float] = []
    for i, (id_list, score_list) in enumerate(zip(ids, scores)):
        # The first ``remainder`` queries take one extra passage.
        quota = avg_len + 1 if i < remainder else avg_len
        new_ids.extend(id_list[:quota])
        new_scores.extend(score_list[:quota])
    return new_ids, new_scores
def get_bm25_pkl_name(bm25_tokenizer: str):
    """
    Build the pickle file name for a bm25 tokenizer.

    :param bm25_tokenizer: Tokenizer name; every ``/`` in it is dropped.
    :return: File name of the form ``bm25_<tokenizer>.pkl``.
    """
    # Splitting on "/" and re-joining removes every slash from the name.
    bm25_tokenizer = "".join(bm25_tokenizer.split("/"))
    return f"bm25_{bm25_tokenizer}.pkl"