from typing import Dict, List

from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
from llama_index.llms.openai.utils import to_openai_message_dicts
from openai import AsyncClient
from pydantic import BaseModel

from autorag.data.qa.filter.prompt import FILTER_PROMPT

dont_know_phrases = {
    "en": [
        "I don't know",
        "I do not know",
        "Don't know",
        "Do not know",
    ],
    "ko": [
        "몰라요",
        "모르겠습니다",
        "모르겠어요",
        "몰라",
        "내가 어떻게 알아?",
        "모르겠소",
        "몰라유",
        "모르것는디",
        "모르겠어유",
        "모르겠네유",
        "모르겠네요",
    ],
    "ja": [
        "知りません",
        "わかりません",
        "分かりません",
        "知らないです",
        "よく分かってません",
        "わかりかねます",
        "存じません",
        "お答えいたしかねます",
    ],
}


def dontknow_filter_rule_based(row: Dict, lang: str = "en") -> bool:
    """
    Rule-based "don't know" filter.
    Returns False (drop the row) if any generation_gt string contains one of
    the known "don't know" phrases for the given language (en, ko, or ja).
    """
    assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
    dont_know_phrase = dont_know_phrases[lang]
    # Keep the row only if no ground-truth answer contains a "don't know" phrase.
    return not any(
        phrase in s for phrase in dont_know_phrase for s in row["generation_gt"]
    )
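
# Usage sketch (illustrative; the row shape mirrors a QA dataset entry, and the
# expected outcomes follow from the phrase lists above):
#
#     row = {"generation_gt": ["I don't know the answer to that."]}
#     dontknow_filter_rule_based(row, lang="en")  # False -> row is dropped
#
#     row = {"generation_gt": ["Paris is the capital of France."]}
#     dontknow_filter_rule_based(row, lang="en")  # True -> row is kept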


class Response(BaseModel):
    """Structured output schema for the OpenAI-based filter."""

    is_dont_know: bool


async def dontknow_filter_openai(
    row: Dict,
    client: AsyncClient,
    model_name: str = "gpt-4o-mini-2024-07-18",
    lang: str = "en",
) -> bool:
    """
    This will drop rows that have a "don't know" answer.
    It will drop unanswerable questions from the QA dataset.
    You can use this filter with the `batch_filter` function of the `QA` class.

    :param row: The row dict from the QA dataset.
    :param client: The OpenAI client.
    :param model_name: The model name.
        You have to use gpt-4o-2024-08-06 or gpt-4o-mini-2024-07-18.
    :param lang: The supported language is en, ko, or ja.
    :return: False if the row's generation_gt means "don't know".
    """
    assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
    system_prompt: List[ChatMessage] = FILTER_PROMPT["dontknow_filter"][lang]
    result = []
    for gen_gt in row["generation_gt"]:
        # Ask the model to classify each ground-truth answer via structured output.
        completion = await client.beta.chat.completions.parse(
            model=model_name,
            messages=to_openai_message_dicts(
                system_prompt + [ChatMessage(role=MessageRole.USER, content=gen_gt)]
            ),
            response_format=Response,
        )
        result.append(completion.choices[0].message.parsed.is_dont_know)
    return not any(result)
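
# Usage sketch (illustrative; assumes OPENAI_API_KEY is set in the environment,
# so the client can be constructed without arguments):
#
#     import asyncio
#     from openai import AsyncClient
#
#     row = {"generation_gt": ["I don't know."]}
#     keep = asyncio.run(dontknow_filter_openai(row, AsyncClient(), lang="en"))
#     # keep is False, so batch_filter would drop this row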


async def dontknow_filter_llama_index(
    row: Dict,
    llm: BaseLLM,
    lang: str = "en",
) -> bool:
    """
    This will drop rows that have a "don't know" answer.
    It will drop unanswerable questions from the QA dataset.
    You can use this filter with the `batch_filter` function of the `QA` class.

    :param row: The row dict from the QA dataset.
    :param llm: The LlamaIndex LLM instance.
        It is a good idea to set max tokens low to save tokens.
    :param lang: The supported language is en, ko, or ja.
    :return: False if the row's generation_gt means "don't know".
    """
    assert "generation_gt" in row.keys(), "generation_gt column is not in the row."
    system_prompt: List[ChatMessage] = FILTER_PROMPT["dontknow_filter"][lang]
    results = []
    for gen_gt in row["generation_gt"]:
        response: ChatResponse = await llm.achat(
            messages=system_prompt
            + [ChatMessage(role=MessageRole.USER, content=gen_gt)]
        )
        result_str = response.message.content
        # The prompt asks for a bare True/False answer; parse it leniently.
        results.append("true" in result_str.lower().strip())
    return not any(results)
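
# Usage sketch (illustrative; OpenAI here is the llama-index-llms-openai wrapper,
# and max_tokens is kept low because only a short True/False answer is expected):
#
#     import asyncio
#     from llama_index.llms.openai import OpenAI
#
#     llm = OpenAI(model="gpt-4o-mini-2024-07-18", max_tokens=8)
#     row = {"generation_gt": ["모르겠습니다"]}
#     keep = asyncio.run(dontknow_filter_llama_index(row, llm, lang="ko"))
#     # keep is False, so batch_filter would drop this row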