# Imports reconstructed so the listing is self-contained; module paths
# follow the AutoRAG and llama-index-core package layouts.
from typing import List, Tuple

from llama_index.core.base.llms.base import BaseLLM
from transformers import AutoTokenizer

from autorag import generator_models
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils.util import get_event_loop, pop_params, process_batch


class LlamaIndexLLM(BaseGenerator):
    def __init__(self, project_dir: str, llm: str, batch: int = 16, *args, **kwargs):
        """
        Initialize the Llama Index LLM module.

        :param project_dir: The project directory.
        :param llm: The name of the llama index LLM to use
            (a key of autorag.generator_models).
        :param batch: The batch size for the llm.
            Set it low if you face errors. Default is 16.
        :param kwargs: The extra parameters for initializing the llm instance.
        """
        super().__init__(project_dir=project_dir, llm=llm)
        if self.llm not in generator_models.keys():
            raise ValueError(
                f"{self.llm} is not a valid llm name. Please check the llm name."
                "You can check valid llm names from autorag.generator_models."
            )
        self.batch = batch
        llm_class = generator_models[self.llm]
        # HuggingFace-based LLM classes expect `model_name` (plus a matching
        # `tokenizer_name`) instead of the `model` parameter.
        if llm_class.class_name() in [
            "HuggingFace_LLM",
            "HuggingFaceInferenceAPI",
            "TextGenerationInference",
        ]:
            model_name = kwargs.pop("model", None)
            if model_name is not None:
                kwargs["model_name"] = model_name
            else:
                if "model_name" not in kwargs.keys():
                    raise ValueError(
                        "`model` or `model_name` parameter must be provided "
                        "for using a HuggingFace LLM."
                    )
            kwargs["tokenizer_name"] = kwargs["model_name"]
        # Keep only the kwargs that the target LLM class actually accepts.
        self.llm_instance: BaseLLM = llm_class(**pop_params(llm_class.__init__, kwargs))

    def __del__(self):
        super().__del__()
        del self.llm_instance
    def _pure(
        self,
        prompts: List[str],
    ) -> Tuple[List[str], List[List[int]], List[List[float]]]:
        """
        Llama Index LLM module.
        It gets the LLM instance from llama index and returns text generated
        from the input prompts.
        It does not compute real log probs; it returns pseudo log probs,
        which are not meant to be used by other modules.

        :param prompts: A list of prompts.
        :return: A tuple of three elements.
            The first element is a list of generated texts.
            The second element is a list of the generated texts' token ids,
            tokenized with GPT2Tokenizer.
            The third element is a list of the generated texts' pseudo log probs.
        """
        # Run the async completions concurrently, `self.batch` requests at a time.
        tasks = [self.llm_instance.acomplete(prompt) for prompt in prompts]
        loop = get_event_loop()
        results = loop.run_until_complete(process_batch(tasks, batch_size=self.batch))
        generated_texts = list(map(lambda x: x.text, results))
        # Tokenize the outputs with GPT2 only to produce token ids; the log
        # probs are a constant placeholder (0.5 per token), not real values.
        tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
        tokenized_ids = tokenizer(generated_texts).data["input_ids"]
        pseudo_log_probs = list(map(lambda x: [0.5] * len(x), tokenized_ids))
        return generated_texts, tokenized_ids, pseudo_log_probs
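
# A minimal usage sketch, not part of the module itself. It assumes an
# AutoRAG project directory already exists at ./project (a hypothetical
# path) and that an OpenAI API key is configured; "openai" is a key in
# autorag.generator_models, and `model` is passed through to the llama
# index OpenAI class. In normal operation `_pure` is invoked by the
# AutoRAG pipeline rather than called directly; it is called here only
# to illustrate the three return values.
if __name__ == "__main__":
    generator = LlamaIndexLLM(
        project_dir="./project",
        llm="openai",
        model="gpt-4o-mini",
        batch=4,  # lower the batch size if you hit rate-limit errors
    )
    texts, token_ids, log_probs = generator._pure(
        ["What is AutoRAG?", "Summarize RAG in one sentence."]
    )
    print(texts[0])
    print(len(token_ids[0]), log_probs[0][:5])  # pseudo log probs are all 0.5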