Source code for autorag.nodes.passageaugmenter.prev_next_augmenter

from typing import List, Union

import numpy as np
import pandas as pd

from autorag.embedding.base import EmbeddingModel
from autorag.evaluation.metric.util import calculate_cosine_similarity
from autorag.nodes.passageaugmenter.base import BasePassageAugmenter
from autorag.utils.util import (

class PrevNextPassageAugmenter(BasePassageAugmenter):
    """
    Passage augmenter that adds the previous and/or next passages of each retrieved
    passage (following the ``prev_id`` / ``next_id`` keys of the corpus metadata),
    then scores every passage by cosine similarity to its query.
    """

    def __init__(
        self,
        project_dir: str,
        embedding_model: Union[str, dict] = "openai",
        *args,
        **kwargs,
    ):
        """
        Initialize the PrevNextPassageAugmenter module.

        :param project_dir: The project directory, forwarded to ``BasePassageAugmenter``.
        :param embedding_model: The embedding model name (or config dict) to use for
            calculating cosine similarity.
            Default is openai (text-embedding-ada-002).
        :param args: Forwarded to ``BasePassageAugmenter``.
        :param kwargs: Forwarded to ``BasePassageAugmenter``.
        """
        super().__init__(project_dir, *args, **kwargs)
        # Keep only the columns neighbour lookup needs. The explicit copy() makes this
        # an independent frame, so the `.loc` assignment below neither triggers
        # pandas' SettingWithCopyWarning nor risks touching self.corpus_df.
        slim_corpus_df = self.corpus_df[["doc_id", "metadata"]].copy()
        slim_corpus_df.loc[:, "metadata"] = slim_corpus_df["metadata"].apply(
            filter_dict_keys, keys=["prev_id", "next_id"]
        )
        self.slim_corpus_df = slim_corpus_df

        # init embedding model
        self.embedding_model = EmbeddingModel.load(embedding_model)()

    def __del__(self):
        # __init__ may have raised before the embedding model was assigned; guard so
        # the finalizer does not raise AttributeError in that case.
        if hasattr(self, "embedding_model"):
            del self.embedding_model
        empty_cuda_cache()
        super().__del__()

    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Run the passage augmenter node - PrevNextPassageAugmenter module.

        Each retrieved passage is expanded with its neighbouring passages. Every
        resulting passage is scored by cosine similarity between its embedding and
        the query embedding. The results are handed to ``self.sort_by_scores``
        together with ``top_k``.

        :param previous_result: The previous result DataFrame. It must contain a
            ``query`` column; retrieved ids are obtained via ``self.cast_to_run``.
        :param kwargs: Module parameters:
            ``top_k`` (required) - the number of results to keep;
            ``mode`` - 'prev', 'next' or 'both' (default 'both');
            ``num_passages`` - maximum neighbours to add in each direction (default 1).
        :return: DataFrame with retrieved_contents, retrieved_ids, and retrieve_scores columns
        :raises KeyError: If ``top_k`` is not given.
        :raises AssertionError: If ``previous_result`` has no ``query`` column.
        :raises ValueError: If ``mode`` is not 'prev', 'next' or 'both'.
        """
        top_k = kwargs.pop("top_k")
        ids = self.cast_to_run(previous_result)
        # find queries columns
        assert (
            "query" in previous_result.columns
        ), "previous_result must have query column."
        queries = previous_result["query"].tolist()
        mode = kwargs.pop("mode", "both")
        num_passages = kwargs.pop("num_passages", 1)
        augmented_ids = self._pure(ids, num_passages, mode)
        # fetch contents from corpus to use augmented ids
        augmented_contents = fetch_contents(self.corpus_df, augmented_ids)
        # presumably one query embedding per query and one list of content embeddings
        # per query (they are zipped pairwise below) - confirm against the helper.
        query_embeddings, contents_embeddings = embedding_query_content(
            queries, augmented_contents, self.embedding_model, batch=128
        )
        # get scores from calculated cosine similarity; np.array(...).tolist()
        # converts the per-passage scores to plain Python floats.
        augmented_scores = [
            np.array(
                [
                    calculate_cosine_similarity(query_embedding, x)
                    for x in content_embeddings
                ]
            ).tolist()
            for query_embedding, content_embeddings in zip(
                query_embeddings, contents_embeddings
            )
        ]
        return self.sort_by_scores(
            augmented_contents, augmented_ids, augmented_scores, top_k
        )

    def _pure(
        self,
        ids_list: List[List[str]],
        num_passages: int = 1,
        mode: str = "both",
    ) -> List[List[str]]:
        """
        Add passages before and/or after each retrieved passage.

        :param ids_list: The list of lists of ids retrieved (one inner list per query).
        :param num_passages: The maximum number of passages to add before and/or after
            each retrieved passage. Fewer are added when the chain of neighbours ends.
            Default is 1.
        :param mode: The mode of augmentation
            'prev': add passages before the retrieved passage
            'next': add passages after the retrieved passage
            'both': add passages before and after the retrieved passage
            Default is 'both'.
        :return: The list of lists of augmented ids, one inner list per query.
        :raises ValueError: If ``mode`` is not 'prev', 'next' or 'both'.
        """
        if mode not in ["prev", "next", "both"]:
            raise ValueError(f"mode must be 'prev', 'next', or 'both', but got {mode}")
        return [
            prev_next_augmenter_pure(ids, self.slim_corpus_df, mode, num_passages)
            for ids in ids_list
        ]
def prev_next_augmenter_pure(
    ids: List[str], corpus_df: pd.DataFrame, mode: str, num_passages: int
) -> List[str]:
    """
    Expand one query's retrieved ids with their neighbouring passages.

    Neighbours are found by following the ``prev_id`` / ``next_id`` keys of each
    row's ``metadata`` (which must support ``.get``).

    :param ids: Retrieved ids for a single query.
    :param corpus_df: Corpus with ``doc_id`` and ``metadata`` columns. When a doc_id
        appears more than once, the first row's metadata is used.
    :param mode: 'prev', 'next' or 'both'. Any other value adds no neighbours.
    :param num_passages: Maximum number of neighbours to follow in each direction.
    :return: Flat list in which every retrieved id appears as
        ``[prev_n, ..., prev_1, id, next_1, ..., next_n]``. Duplicates are kept.
        A chain stops early at a missing link (``None``), at a link to a doc_id that
        is not in ``corpus_df``, or at a starting id that is not in ``corpus_df``.
    """
    # Index metadata by doc_id once. This replaces a full boolean-mask scan of the
    # corpus on every single hop. setdefault keeps the first row for duplicated
    # doc_ids, matching the previous `.values[0]` lookup.
    metadata_by_id = {}
    for doc_id, metadata in zip(corpus_df["doc_id"], corpus_df["metadata"]):
        metadata_by_id.setdefault(doc_id, metadata)

    def fetch_id_sequence(start_id, key):
        """Follow ``key`` links from ``start_id`` for at most ``num_passages`` hops."""
        sequence = []
        current_id = start_id
        for _ in range(num_passages):
            metadata = metadata_by_id.get(current_id)
            if metadata is None:
                # The start id is not in the corpus: there is nothing to follow.
                break
            current_id = metadata.get(key)
            if current_id is None or current_id not in metadata_by_id:
                # End of the chain, or a link to a passage missing from the corpus.
                # Previously the dangling id was appended and crashed the next hop.
                break
            sequence.append(current_id)
        return sequence

    augmented_group = []
    for id_ in ids:
        current_ids = [id_]
        if mode in ["prev", "both"]:
            # Collected nearest-first, so reverse to get document order.
            current_ids = fetch_id_sequence(id_, "prev_id")[::-1] + current_ids
        if mode in ["next", "both"]:
            current_ids += fetch_id_sequence(id_, "next_id")
        augmented_group.extend(current_ids)
    return augmented_group