Source code for autorag.nodes.passagecompressor.tree_summarize

from typing import List, Optional

from llama_index.core import PromptTemplate
from llama_index.core.prompts import PromptType
from llama_index.core.prompts.utils import is_chat_model
from llama_index.core.response_synthesizers import TreeSummarize as ts

from autorag.nodes.passagecompressor.base import LlamaIndexCompressor
from autorag.utils.util import get_event_loop, process_batch


class TreeSummarize(LlamaIndexCompressor):
    def _pure(
        self,
        queries: List[str],
        contents: List[List[str]],
        prompt: Optional[str] = None,
        chat_prompt: Optional[str] = None,
        batch: int = 16,
    ) -> List[str]:
        """
        Recursively merges retrieved texts and summarizes them in a bottom-up fashion.
        This function is a wrapper for llama_index.response_synthesizers.TreeSummarize.
        For more information, visit
        https://docs.llamaindex.ai/en/latest/examples/response_synthesizers/tree_summarize.html.

        :param queries: The queries for retrieved passages.
        :param contents: The contents of retrieved passages.
        :param prompt: The prompt template for summarization.
            If you want to use a chat prompt, pass chat_prompt instead.
            The prompt must specify where to put 'context_str' and 'query_str'.
            Default is None. When it is None, the llama index default prompt is used.
        :param chat_prompt: The chat prompt template for summarization.
            If you want to use a normal prompt, pass prompt instead.
            The prompt must specify where to put 'context_str' and 'query_str'.
            Default is None. When it is None, the llama index default chat prompt is used.
        :param batch: The batch size for the LLM.
            Set this lower if you encounter errors.
            Default is 16.
        :return: The list of compressed texts.
        """
        if prompt is not None and not is_chat_model(self.llm):
            summary_template = PromptTemplate(prompt, prompt_type=PromptType.SUMMARY)
        elif chat_prompt is not None and is_chat_model(self.llm):
            summary_template = PromptTemplate(
                chat_prompt, prompt_type=PromptType.SUMMARY
            )
        else:
            summary_template = None
        summarizer = ts(
            llm=self.llm, summary_template=summary_template, use_async=True
        )
        tasks = [
            summarizer.aget_response(query, content)
            for query, content in zip(queries, contents)
        ]
        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
        return results
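
# --- Illustrative usage sketch (not part of the original module) ------------
# A minimal sketch of the llama_index calls that _pure performs for a single
# query, assuming an OpenAI LLM. The model name and prompt text below are
# assumptions; any custom prompt must contain the '{context_str}' and
# '{query_str}' placeholders.
#
#     from llama_index.llms.openai import OpenAI
#
#     llm = OpenAI(model="gpt-4o-mini")
#     template = PromptTemplate(
#         "Answer the question using only the context below.\n"
#         "Question: {query_str}\nContext: {context_str}\nSummary:",
#         prompt_type=PromptType.SUMMARY,
#     )
#     summarizer = ts(llm=llm, summary_template=template, use_async=False)
#     summary = summarizer.get_response(
#         "What does tree summarization do?",
#         ["first retrieved passage ...", "second retrieved passage ..."],
#     )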