@generator_node
def vllm(prompts: List[str], llm: str, **kwargs) -> Tuple[List[str], List[List[int]], List[List[float]]]:
    """
    Vllm module.
    It gets the VLLM instance, and returns generated texts by the input prompt.
    You can set logprobs to get the log probs of the generated text.
    Default logprobs is 1.

    :param prompts: A list of prompts.
    :param llm: Model name of vLLM.
    :param kwargs: The extra parameters for generating the text.
    :return: A tuple of three elements.
        The first element is a list of generated text.
        The second element is a list of generated text's token ids.
        The third element is a list of generated text's log probs.
    """
    try:
        from vllm.outputs import RequestOutput
        from vllm.sequence import SampleLogprobs
        from vllm import SamplingParams
    except ImportError:
        raise ImportError("Please install vllm library. You can install it by running `pip install vllm`.")

    # Work on a copy so the caller's kwargs are never mutated; the helper
    # consumes the LLM-constructor keys from this copy in place.
    sampling_kwargs = deepcopy(kwargs)
    model = make_vllm_instance(llm, sampling_kwargs)
    # Request at least one logprob per generated token unless the caller set it.
    sampling_kwargs.setdefault("logprobs", 1)
    params = SamplingParams(**sampling_kwargs)

    outputs: List[RequestOutput] = model.generate(prompts, params)
    texts = [out.outputs[0].text for out in outputs]
    token_ids = [out.outputs[0].token_ids for out in outputs]
    # One logprob-dict per generated token, keyed by token id.
    logprob_dicts: List[SampleLogprobs] = [out.outputs[0].logprobs for out in outputs]
    log_probs = [
        [per_token[tok_id].logprob for per_token, tok_id in zip(seq_dicts, seq_ids)]
        for seq_dicts, seq_ids in zip(logprob_dicts, token_ids)
    ]

    destroy_vllm_instance(model)
    return texts, token_ids, log_probs
def make_vllm_instance(llm: str, input_args: Dict):
    """
    Build a vLLM ``LLM`` instance from *llm* and the mixed kwargs in *input_args*.

    Keys in *input_args* that are NOT accepted by ``SamplingParams.from_optional``
    are treated as ``LLM``-constructor kwargs and are popped from *input_args*.
    This mutation is a deliberate side effect: the caller passes the remaining
    keys on to ``SamplingParams``. An explicit ``model`` key in *input_args*
    overrides the *llm* argument.

    :param llm: Default model name, used unless ``input_args["model"]`` is set.
    :param input_args: Mixed vLLM kwargs; mutated in place (consumed keys removed).
    :return: A constructed ``vllm.LLM`` instance.
    """
    from vllm import LLM
    from vllm import SamplingParams

    # An explicit "model" kwarg takes precedence over the llm argument.
    model_from_args = input_args.pop("model", None)
    model = llm if model_from_args is None else model_from_args

    # Names accepted by SamplingParams; use a set for O(1) membership tests
    # (the original looked each key up in a list).
    sampling_param_names = {
        param.name
        for param in inspect.signature(SamplingParams.from_optional).parameters.values()
    }
    # Everything not meant for SamplingParams goes to the LLM constructor.
    llm_kwargs = {key: value for key, value in input_args.items() if key not in sampling_param_names}
    # Remove the consumed keys so the caller only sees SamplingParams kwargs.
    for key in llm_kwargs:
        input_args.pop(key)

    return LLM(model, **llm_kwargs)