class Vllm(BaseGenerator):
    """Generator module that produces text with a locally loaded vLLM model.

    The vLLM ``LLM`` instance is created once in ``__init__`` and stored on
    ``self.vllm_model``; ``_pure`` uses it to generate completions together
    with their token ids and log probabilities.
    """

    def __init__(self, project_dir: str, llm: str, **kwargs):
        """Load the vLLM model.

        :param project_dir: Forwarded unchanged to ``BaseGenerator.__init__``.
        :param llm: The model name or path given to ``vllm.LLM``. It is used
            unless a ``model`` keyword argument is supplied.
        :param kwargs: Extra keyword arguments. ``model`` (optional) overrides
            ``llm``. The rest are run through
            ``pop_params(SamplingParams.from_optional, ...)`` and what is left
            is passed to ``vllm.LLM``.
        :raises ImportError: If the ``vllm`` package is not installed.
        """
        super().__init__(project_dir, llm, **kwargs)
        try:
            from vllm import SamplingParams, LLM
        except ImportError as err:
            raise ImportError(
                "Please install vllm library. You can install it by running `pip install vllm`."
            ) from err

        model_from_kwargs = kwargs.pop("model", None)
        model = llm if model_from_kwargs is None else model_from_kwargs

        # Work on a copy so the engine arguments can be separated from the
        # sampling arguments without touching the original kwargs.
        # NOTE(review): this relies on pop_params removing the matched keys
        # from input_kwargs in place -- confirm against its definition.
        input_kwargs = deepcopy(kwargs)
        sampling_params_init_params = pop_params(
            SamplingParams.from_optional, input_kwargs
        )
        self.vllm_model = LLM(model, **input_kwargs)

        # delete not sampling param keys in the kwargs
        # NOTE(review): kwargs is this call's own ``**kwargs`` dict, so the
        # pruning below is not visible to the caller or to the base class.
        for key in list(kwargs):
            if key not in sampling_params_init_params:
                kwargs.pop(key)

    def __del__(self):
        """Release the vLLM engine and, when CUDA is available, GPU state.

        The cleanup is skipped when ``vllm_model`` was never assigned (for
        example when ``__init__`` raised before the model was created), so
        finalizing a half-constructed instance does not raise
        ``AttributeError``. ``super().__del__()`` is called afterwards.
        """
        if hasattr(self, "vllm_model"):
            try:
                import torch
                import contextlib

                if torch.cuda.is_available():
                    from vllm.distributed.parallel_state import (
                        destroy_model_parallel,
                        destroy_distributed_environment,
                    )

                    destroy_model_parallel()
                    destroy_distributed_environment()
                    del self.vllm_model.llm_engine.model_executor
                    del self.vllm_model
                    # destroy_process_group may assert; that case is
                    # deliberately ignored.
                    with contextlib.suppress(AssertionError):
                        torch.distributed.destroy_process_group()
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
            except ImportError:
                # torch / vllm internals are unavailable: just drop the model.
                del self.vllm_model

        super().__del__()

    def _pure(
        self, prompts: List[str], **kwargs
    ) -> Tuple[List[str], List[List[int]], List[List[float]]]:
        """
        Vllm module.
        It gets the VLLM instance and returns generated texts by the input prompt.
        You can set logprobs to get the log probs of the generated text.
        Default logprobs is 1.

        Only the first output (``outputs[0]``) of each request is used.

        :param prompts: A list of prompts.
        :param kwargs: The extra parameters for generating the text.
            They are filtered with ``pop_params(SamplingParams.from_optional, ...)``
            and passed to ``SamplingParams``.
        :return: A tuple of three elements.
            The first element is a list of generated text.
            The second element is a list of generated text's token ids.
            The third element is a list of generated text's log probs.
        :raises ImportError: If the ``vllm`` package is not installed.
        """
        try:
            from vllm.outputs import RequestOutput
            from vllm.sequence import SampleLogprobs
            from vllm import SamplingParams
        except ImportError as err:
            raise ImportError(
                "Please install vllm library. You can install it by running `pip install vllm`."
            ) from err

        # Log probs are part of the return value, so request them by default.
        kwargs.setdefault("logprobs", 1)

        sampling_params = pop_params(SamplingParams.from_optional, kwargs)
        generate_params = SamplingParams(**sampling_params)
        results: List[RequestOutput] = self.vllm_model.generate(
            prompts, generate_params
        )

        first_outputs = [result.outputs[0] for result in results]
        generated_texts = [output.text for output in first_outputs]
        generated_token_ids = [output.token_ids for output in first_outputs]
        log_probs: List[SampleLogprobs] = [
            output.logprobs for output in first_outputs
        ]
        # For every generated position, look up the entry of the token that
        # was actually produced and take its log probability.
        generated_log_probs = [
            [
                step_logprobs[token_id].logprob
                for step_logprobs, token_id in zip(steps, token_ids)
            ]
            for steps, token_ids in zip(log_probs, generated_token_ids)
        ]
        return (
            to_list(generated_texts),
            to_list(generated_token_ids),
            to_list(generated_log_probs),
        )