Source code for autorag.data.qa.filter.dontknow

from typing import Dict, List

from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
from llama_index.llms.openai.utils import to_openai_message_dicts
from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.filter.prompt import FILTER_PROMPT

# Per-language phrases that mark a ground-truth answer as "I don't know".
# Keys are language codes (en / ko / ja); dontknow_filter_rule_based looks the
# list up by its `lang` argument and does a case-sensitive substring search.
dont_know_phrases = dict(
	# English: both contracted and uncontracted forms, with and without "I".
	en=[
		"I don't know",
		"I do not know",
		"Don't know",
		"Do not know",
	],
	# Korean: polite, casual and regional-dialect ways of saying "I don't know",
	# plus the rhetorical "How would I know?".
	ko=[
		"몰라요",
		"모르겠습니다",
		"모르겠어요",
		"몰라",
		"내가 어떻게 알아?",
		"모르겠소",
		"몰라유",
		"모르것는디",
		"모르겠어유",
		"모르겠네유",
		"모르겠네요",
	],
	# Japanese: plain, kana/kanji spelling variants, and honorific refusals
	# such as "I am unable to answer".
	ja=[
		"知りません",
		"わかりません",
		"分かりません",
		"知らないです",
		"よく分かってません",
		"わかりかねます",
		"存じません",
		"お答えいたしかねます",
	],
)


def dontknow_filter_rule_based(row: Dict, lang: str = "en") -> bool:
	"""
	Drop rows whose generation_gt contains a known "don't know" phrase.

	Each answer in ``row["generation_gt"]`` is checked for a case-sensitive
	substring match against ``dont_know_phrases[lang]``. Before matching,
	typographic apostrophes (‘ ’ ʼ) in the answers are mapped to the ASCII
	apostrophe. This way an answer like "I don’t know" is caught by the
	"I don't know" phrase.

	:param row: The row dict from QA dataset. Must contain "generation_gt",
		an iterable of answer strings.
	:param lang: The phrase language: en, ko or ja (a key of ``dont_know_phrases``).
	:return: False if any generation_gt contains a "don't know" phrase
		(the row should be dropped), True otherwise.
	"""
	assert (
		"generation_gt" in row.keys()
	), "generation_gt column is not in the DataFrame."
	dont_know_phrase = dont_know_phrases[lang]
	# LLM-written answers frequently use typographic apostrophes, which would
	# otherwise never match the ASCII-apostrophe English phrases.
	apostrophe_table = str.maketrans({"\u2018": "'", "\u2019": "'", "\u02bc": "'"})
	# Materialise once: the phrase loop below walks the answers repeatedly.
	answers = [
		s.translate(apostrophe_table) if isinstance(s, str) else s
		for s in row["generation_gt"]
	]
	return not any(phrase in s for phrase in dont_know_phrase for s in answers)
# Structured-output schema for dontknow_filter_openai: it is passed as
# `response_format` to the OpenAI `parse` call, and the parsed instance's
# `is_dont_know` flag is read back from the completion.
# NOTE: deliberately has no class docstring — pydantic would copy it into the
# JSON schema "description" sent to the model.
class Response(BaseModel):
	# True when the model classified the answer as a "don't know" reply;
	# any True makes dontknow_filter_openai drop the row.
	is_dont_know: bool
async def dontknow_filter_openai(
	row: Dict,
	client: AsyncClient,
	model_name: str = "gpt-4o-mini-2024-07-18",
	lang: str = "en",
) -> bool:
	"""
	This will drop rows that have a "don't know" answer.
	It will drop unanswerable questions from the QA dataset.
	You can use this filter with the `batch_filter` function at `QA` class.

	Each generation_gt answer is sent to the model after the language-specific
	"dontknow_filter" system prompt. The reply is parsed into a `Response`.
	Evaluation stops at the first answer flagged as "don't know", so no further
	API calls are made for that row.

	:param row: The row dict from QA dataset. Must contain "generation_gt".
	:param client: The OpenAI client.
	:param model_name: The model name.
		You have to use gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18.
	:param lang: The supported language is en, ko or ja.
	:return: False if the row generation_gt is a "don't know" meaning.
	:raises ValueError: If the model refuses to answer, so no parsed `Response`
		is available for an answer.
	"""
	assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
	system_prompt: List[ChatMessage] = FILTER_PROMPT["dontknow_filter"][lang]
	for gen_gt in row["generation_gt"]:
		completion = await client.beta.chat.completions.parse(
			model=model_name,
			messages=to_openai_message_dicts(
				system_prompt + [ChatMessage(role=MessageRole.USER, content=gen_gt)]
			),
			response_format=Response,
		)
		message = completion.choices[0].message
		# Structured-output parsing leaves `parsed` as None when the model
		# refuses; surface the refusal instead of an opaque AttributeError.
		if message.parsed is None:
			raise ValueError(
				f"Model refused to classify generation_gt {gen_gt!r}: {message.refusal}"
			)
		if message.parsed.is_dont_know:
			# One "don't know" answer is enough to drop the row; skip the rest.
			return False
	return True
async def dontknow_filter_llama_index(
	row: Dict,
	llm: BaseLLM,
	lang: str = "en",
) -> bool:
	"""
	This will drop rows that have a "don't know" answer.
	It will drop unanswerable questions from the QA dataset.
	You can use this filter with the `batch_filter` function at `QA` class.

	Each generation_gt answer is sent to the LLM after the language-specific
	"dontknow_filter" system prompt. A reply containing "true"
	(case-insensitive) marks that answer as "don't know". Evaluation stops at
	the first such answer, so no further LLM calls are made for that row.

	:param row: The row dict from QA dataset. Must contain "generation_gt".
	:param llm: The Llama index llm instance.
		It will be good if you set max tokens to low for saving tokens.
	:param lang: The supported language is en, ko or ja.
	:return: False if the row generation_gt is a "don't know" meaning.
	"""
	assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
	system_prompt: List[ChatMessage] = FILTER_PROMPT["dontknow_filter"][lang]
	for gen_gt in row["generation_gt"]:
		response: ChatResponse = await llm.achat(
			messages=system_prompt
			+ [ChatMessage(role=MessageRole.USER, content=gen_gt)]
		)
		# ChatMessage.content may be None (e.g. an empty reply); treat that like
		# any other reply without "true" instead of crashing on None.lower().
		result_str = response.message.content or ""
		if "true" in result_str.lower():
			# One "don't know" answer is enough to drop the row; skip the rest.
			return False
	return True