Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for autorag.nodes.queryexpansion.base
import abc
import logging
from pathlib import Path
from typing import List , Union
import pandas as pd
from autorag.nodes.util import make_generator_callable_param
from autorag.schema import BaseModule
from autorag.utils import validate_qa_dataset
# Module-level logger. It is looked up under the fixed name "AutoRAG" rather
# than this module's own ``__name__``.
logger = logging . getLogger ( "AutoRAG" )
[docs]
class BaseQueryExpansion(BaseModule, metaclass=abc.ABCMeta):
    """Abstract base for query expansion modules.

    On construction the instance builds a generator module and stores it on
    ``self.generator``. The generator class and its parameters are resolved
    from the keyword arguments by ``make_generator_callable_param``.
    """

    def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
        """Create the generator module used for query expansion.

        :param project_dir: Project directory, passed as the first argument
            to the generator class.
        :param args: Accepted for signature compatibility; not used here.
        :param kwargs: Handed to ``make_generator_callable_param`` to obtain
            the generator class and its parameters.
        """
        logger.info(
            f"Initialize query expansion node - {self.__class__.__name__} module..."
        )
        # set generator module for query expansion
        generator_class, generator_param = make_generator_callable_param(kwargs)
        self.generator = generator_class(project_dir, **generator_param)

    def __del__(self):
        """Drop the generator reference and log the deletion."""
        # __init__ can raise before ``self.generator`` is assigned. Python
        # still runs __del__ on the half-built object, so a missing attribute
        # must not raise here (it would surface as "Exception ignored in ..."
        # and skip the log line below).
        try:
            del self.generator
        except AttributeError:
            pass
        logger.info(
            f"Delete query expansion node - {self.__class__.__name__} module..."
        )

    def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
        """Validate the previous result and extract its queries.

        :param previous_result: DataFrame from the previous step. It is
            passed to ``validate_qa_dataset`` and must have a ``query``
            column.
        :param args: Accepted for signature compatibility; not used here.
        :param kwargs: Accepted for signature compatibility; not used here.
        :return: The values of the ``query`` column as a list.
        :raises AssertionError: If ``previous_result`` has no ``query``
            column. This check is an ``assert``, so it is skipped when Python
            runs with ``-O``.
        """
        logger.info(
            f"Running query expansion node - {self.__class__.__name__} module..."
        )
        validate_qa_dataset(previous_result)

        # find queries columns
        assert (
            "query" in previous_result.columns
        ), "previous_result must have query column."
        queries = previous_result["query"].tolist()
        return queries

    @staticmethod
    def _check_expanded_query(queries: List[str], expanded_queries: List[List[str]]):
        """Apply ``check_expanded_query`` to each query and its expansions.

        :param queries: Original queries.
        :param expanded_queries: One list of expanded queries per original
            query, paired with ``queries`` by position.
        :return: A list holding the ``check_expanded_query`` result for each
            pair. Pairing stops at the shorter of the two inputs.
        """
        return [
            check_expanded_query(query, expanded_query_list)
            for query, expanded_query_list in zip(queries, expanded_queries)
        ]
[docs]
def check_expanded_query(query: str, expanded_query_list: List[str]):
    """Replace blank expanded queries with the original query.

    Each expanded query is stripped of surrounding whitespace. An entry that
    is empty after stripping is replaced by ``query``; every other entry is
    kept in its stripped form.

    :param query: The original query, used as the fallback value.
    :param expanded_query_list: Expanded queries produced for ``query``.
    :return: A new list of the same length as ``expanded_query_list``.
    """
    checked = []
    for candidate in expanded_query_list:
        stripped = candidate.strip()
        # An empty string is falsy, so blank expansions fall back to the
        # original query.
        checked.append(stripped or query)
    return checked