import abc
import logging
import os
from typing import List, Union, Tuple
import pandas as pd
from autorag.schema import BaseModule
from autorag.support import get_support_modules
from autorag.utils import fetch_contents, result_to_dataframe, validate_qa_dataset
from autorag.utils.util import pop_params
logger = logging.getLogger("AutoRAG")
[docs]
class BaseRetrieval(BaseModule, metaclass=abc.ABCMeta):
def __init__(self, project_dir: str, *args, **kwargs):
logger.info(f"Initialize retrieval node - {self.__class__.__name__}")
self.resources_dir = os.path.join(project_dir, "resources")
data_dir = os.path.join(project_dir, "data")
# fetch data from corpus_data
self.corpus_df = pd.read_parquet(
os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
)
def __del__(self):
logger.info(f"Deleting retrieval node - {self.__class__.__name__} module...")
[docs]
def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
logger.info(f"Running retrieval node - {self.__class__.__name__} module...")
validate_qa_dataset(previous_result)
# find queries columns & type cast queries
assert (
"query" in previous_result.columns
), "previous_result must have query column."
if "queries" not in previous_result.columns:
previous_result["queries"] = previous_result["query"]
previous_result.loc[:, "queries"] = previous_result["queries"].apply(
cast_queries
)
queries = previous_result["queries"].tolist()
return queries
[docs]
class HybridRetrieval(BaseRetrieval, metaclass=abc.ABCMeta):
def __init__(
self, project_dir: str, target_modules, target_module_params, *args, **kwargs
):
super().__init__(project_dir)
self.target_modules = list(
map(
lambda x, y: get_support_modules(x)(
**y,
project_dir=project_dir,
),
target_modules,
target_module_params,
)
)
self.target_module_params = target_module_params
[docs]
@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
result_dfs: List[pd.DataFrame] = list(
map(
lambda x, y: x.pure(
**y,
previous_result=previous_result,
),
self.target_modules,
self.target_module_params,
)
)
ids = tuple(
map(lambda df: df["retrieved_ids"].apply(list).tolist(), result_dfs)
)
scores = tuple(
map(
lambda df: df["retrieve_scores"].apply(list).tolist(),
result_dfs,
)
)
_pure_params = pop_params(self._pure, kwargs)
if "ids" in _pure_params or "scores" in _pure_params:
raise ValueError(
"With specifying ids or scores, you must use HybridRRF.run_evaluator instead."
)
ids, scores = self._pure(ids=ids, scores=scores, **_pure_params)
contents = fetch_contents(self.corpus_df, ids)
return contents, ids, scores
[docs]
def cast_queries(queries: Union[str, List[str]]) -> List[str]:
if isinstance(queries, str):
return [queries]
elif isinstance(queries, List):
return queries
else:
raise ValueError(f"queries must be str or list, but got {type(queries)}")
[docs]
def evenly_distribute_passages(
ids: List[List[str]], scores: List[List[float]], top_k: int
) -> Tuple[List[str], List[float]]:
assert len(ids) == len(scores), "ids and scores must have same length."
query_cnt = len(ids)
avg_len = top_k // query_cnt
remainder = top_k % query_cnt
new_ids = []
new_scores = []
for i in range(query_cnt):
if i < remainder:
new_ids.extend(ids[i][: avg_len + 1])
new_scores.extend(scores[i][: avg_len + 1])
else:
new_ids.extend(ids[i][:avg_len])
new_scores.extend(scores[i][:avg_len])
return new_ids, new_scores
[docs]
def get_bm25_pkl_name(bm25_tokenizer: str):
bm25_tokenizer = bm25_tokenizer.replace("/", "")
return f"bm25_{bm25_tokenizer}.pkl"