Source code for autorag.data.qa.generation_gt.llama_index_gen_gt
import itertools
from typing import Dict
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import MessageRole, ChatMessage
from autorag.data.qa.generation_gt.base import add_gen_gt
from autorag.data.qa.generation_gt.prompt import GEN_GT_SYSTEM_PROMPT
[docs]
async def make_gen_gt_llama_index(row: Dict, llm: BaseLLM, system_prompt: str) -> Dict:
retrieval_gt_contents = list(
itertools.chain.from_iterable(row["retrieval_gt_contents"])
)
query = row["query"]
passage_str = "\n".join(retrieval_gt_contents)
user_prompt = f"Text:\n<|text_start|>\n{passage_str}\n<|text_end|>\n\nQuestion:\n{query}\n\nAnswer:"
response = await llm.achat(
messages=[
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt),
ChatMessage(role=MessageRole.USER, content=user_prompt),
],
temperature=0.0,
)
return add_gen_gt(row, response.message.content)
[docs]
async def make_concise_gen_gt(row: Dict, llm: BaseLLM, lang: str = "en") -> Dict:
return await make_gen_gt_llama_index(
row, llm, GEN_GT_SYSTEM_PROMPT["concise"][lang]
)
[docs]
async def make_basic_gen_gt(row: Dict, llm: BaseLLM, lang: str = "en") -> Dict:
return await make_gen_gt_llama_index(row, llm, GEN_GT_SYSTEM_PROMPT["basic"][lang])
[docs]
async def make_custom_gen_gt(row: Dict, llm: BaseLLM, system_prompt: str) -> Dict:
return await make_gen_gt_llama_index(row, llm, system_prompt)