Source code for autorag.nodes.passagecompressor.refine

from typing import List, Optional

from llama_index.core import PromptTemplate
from llama_index.core.prompts import PromptType
from llama_index.core.prompts.utils import is_chat_model
from llama_index.core.response_synthesizers import Refine as rf

from autorag.nodes.passagecompressor.base import LlamaIndexCompressor
from autorag.utils.util import get_event_loop, process_batch


class Refine(LlamaIndexCompressor):
    def _pure(
        self,
        queries: List[str],
        contents: List[List[str]],
        prompt: Optional[str] = None,
        chat_prompt: Optional[str] = None,
        batch: int = 16,
    ) -> List[str]:
        """
        Refine a response to a query across text chunks.
        This function is a wrapper for llama_index.response_synthesizers.Refine.
        For more information, visit
        https://docs.llamaindex.ai/en/stable/examples/response_synthesizers/refine/.

        :param queries: The queries for retrieved passages.
        :param contents: The contents of retrieved passages.
        :param prompt: The prompt template for refine.
            If you want to use a chat prompt, pass chat_prompt instead.
            The prompt must specify where to put 'context_msg' and 'query_str'.
            Default is None, which uses the llama index default prompt.
        :param chat_prompt: The chat prompt template for refine.
            If you want to use a normal prompt, pass prompt instead.
            The prompt must specify where to put 'context_msg' and 'query_str'.
            Default is None, which uses the llama index default chat prompt.
        :param batch: The batch size for the llm.
            Set this low if you encounter errors.
            Default is 16.
        :return: The list of compressed texts.
        """
        # Use a custom refine template only when it matches the model type:
        # a plain prompt for completion models, a chat prompt for chat models.
        # Otherwise fall back to the llama index default (refine_template=None).
        if prompt is not None and not is_chat_model(self.llm):
            refine_template = PromptTemplate(prompt, prompt_type=PromptType.REFINE)
        elif chat_prompt is not None and is_chat_model(self.llm):
            refine_template = PromptTemplate(
                chat_prompt, prompt_type=PromptType.REFINE
            )
        else:
            refine_template = None

        summarizer = rf(llm=self.llm, refine_template=refine_template, verbose=True)
        # One async refine call per query; process_batch gathers them
        # batch_size tasks at a time to avoid overwhelming the llm backend.
        tasks = [
            summarizer.aget_response(query, content)
            for query, content in zip(queries, contents)
        ]
        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
        return results
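
A minimal usage sketch (not part of the module source). In practice Refine is driven by AutoRAG's node pipeline, and the constructor keyword arguments shown here are assumptions about LlamaIndexCompressor's interface, not the documented API; check autorag.nodes.passagecompressor.base before relying on them. The custom prompt follows the docstring's requirement to include the 'context_msg' and 'query_str' placeholders.

# Hypothetical usage sketch; constructor kwargs are assumptions.
from autorag.nodes.passagecompressor.refine import Refine

queries = ["Which city is the capital of France?"]
contents = [[
    "France is a country in Western Europe.",
    "Paris is the capital and most populous city of France.",
]]

# A custom (non-chat) refine prompt; per the docstring it must contain
# the 'context_msg' and 'query_str' placeholders.
prompt = (
    "Context: {context_msg}\n"
    "Given the context, keep only the parts relevant to: {query_str}\n"
)

compressor = Refine(project_dir=".", llm="openai")  # assumed kwargs
compressed = compressor._pure(queries, contents, prompt=prompt, batch=4)
print(compressed[0])  # one compressed text per query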