Source code for autorag.nodes.generator.openai_llm

import logging
import os
from typing import List, Tuple

import tiktoken
from openai import AsyncOpenAI
from tiktoken import Encoding

from autorag.nodes.generator.base import generator_node
from autorag.utils.util import get_event_loop, process_batch

logger = logging.getLogger("AutoRAG")

MAX_TOKEN_DICT = {  # model name : token limit
	"o1-preview": 128_000,
	"o1-preview-2024-09-12": 128_000,
	"o1-mini": 128_000,
	"o1-mini-2024-09-12": 128_000,
	"gpt-4o-mini": 128_000,
	"gpt-4o-mini-2024-07-18": 128_000,
	"gpt-4o": 128_000,
	"gpt-4o-2024-08-06": 128_000,
	"gpt-4o-2024-05-13": 128_000,
	"chatgpt-4o-latest": 128_000,
	"gpt-4-turbo": 128_000,
	"gpt-4-turbo-2024-04-09": 128_000,
	"gpt-4-turbo-preview": 128_000,
	"gpt-4-0125-preview": 128_000,
	"gpt-4-1106-preview": 128_000,
	"gpt-4-vision-preview": 128_000,
	"gpt-4-1106-vision-preview": 128_000,
	"gpt-4": 8_192,
	"gpt-4-0613": 8_192,
	"gpt-4-32k": 32_768,
	"gpt-4-32k-0613": 32_768,
	"gpt-3.5-turbo-0125": 16_385,
	"gpt-3.5-turbo": 16_385,
	"gpt-3.5-turbo-1106": 16_385,
	"gpt-3.5-turbo-instruct": 4_096,
	"gpt-3.5-turbo-16k": 16_385,
	"gpt-3.5-turbo-0613": 4_096,
	"gpt-3.5-turbo-16k-0613": 16_385,
}


@generator_node
def openai_llm(
	prompts: List[str],
	llm: str = "gpt-3.5-turbo",
	batch: int = 16,
	truncate: bool = True,
	api_key: str = None,
	**kwargs,
) -> Tuple[List[str], List[List[int]], List[List[float]]]:
	"""
	OpenAI generator module.
	Uses the official openai library to generate answers for the given prompts.
	It returns real token ids and log probs, so use this module when you need token ids and log probs.

	:param prompts: A list of prompts.
	:param llm: A model name for openai. Default is gpt-3.5-turbo.
	:param batch: Batch size for the openai api call.
		If you get API limit errors, lower the batch size. Default is 16.
	:param truncate: Whether to truncate the input prompt. Default is True.
	:param api_key: OpenAI API key. You can also set this via the env variable `OPENAI_API_KEY`.
	:param kwargs: Optional parameters for the openai chat completion call.
		See https://platform.openai.com/docs/api-reference/chat/create for more details.
	:return: A tuple of three elements.
		The first element is a list of the generated texts.
		The second element is a list of the generated texts' token ids.
		The third element is a list of the generated texts' log probs.
	"""
	if api_key is None:
		api_key = os.getenv("OPENAI_API_KEY")
	if api_key is None:
		raise ValueError(
			"OPENAI_API_KEY is not set. "
			"Please set the env variable OPENAI_API_KEY or pass the api_key parameter to the openai module."
		)

	if kwargs.get("logprobs") is not None:
		kwargs.pop("logprobs")
		logger.warning("The logprobs parameter has no effect. It is always set to True.")
	if kwargs.get("n") is not None:
		kwargs.pop("n")
		logger.warning("The n parameter has no effect. It is always set to 1.")

	# TODO: remove this workaround once tiktoken supports the o1 models.
	if llm.startswith("o1"):
		tokenizer = tiktoken.get_encoding("o200k_base")
	else:
		tokenizer = tiktoken.encoding_for_model(llm)

	if truncate:
		max_token_size = MAX_TOKEN_DICT.get(llm)
		if max_token_size is None:
			raise ValueError(
				f"Model {llm} is not supported. "
				f"Please select a model from {list(MAX_TOKEN_DICT.keys())}"
			)
		max_token_size -= 7  # reserve room for chat message overhead tokens
		prompts = list(
			map(
				lambda prompt: truncate_by_token(prompt, tokenizer, max_token_size),
				prompts,
			)
		)

	client = AsyncOpenAI(api_key=api_key)
	loop = get_event_loop()

	if llm.startswith("o1"):
		tasks = [
			get_result_o1(prompt, client, llm, tokenizer, **kwargs)
			for prompt in prompts
		]
	else:
		tasks = [
			get_result(prompt, client, llm, tokenizer, **kwargs)
			for prompt in prompts
		]
	result = loop.run_until_complete(process_batch(tasks, batch))

	answer_result = list(map(lambda x: x[0], result))
	token_result = list(map(lambda x: x[1], result))
	logprob_result = list(map(lambda x: x[2], result))
	return answer_result, token_result, logprob_result
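
For reference, the three returned lists are parallel: one generated text, one token-id list, and one log-prob list per prompt, with each token-id list the same length as its log-prob list. A minimal sketch of that contract with purely illustrative values (no API call is made; inside an AutoRAG pipeline the generator_node wrapper supplies the prompts):

# Illustrative only: the shape of the tuple openai_llm returns for two prompts.
answers = ["Paris is the capital of France.", "42."]
token_ids = [[12342, 318, 262, 3139, 286, 4881, 13], [3682, 13]]  # hypothetical ids
log_probs = [[-0.01, -0.2, -0.05, -0.1, -0.03, -0.02, -0.4], [-0.3, -0.6]]

assert len(answers) == len(token_ids) == len(log_probs)
for ids, lps in zip(token_ids, log_probs):
	assert len(ids) == len(lps)  # one log prob per generated token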
async def get_result(
	prompt: str, client: AsyncOpenAI, model: str, tokenizer: Encoding, **kwargs
):
	response = await client.chat.completions.create(
		model=model,
		messages=[
			{"role": "user", "content": prompt},
		],
		logprobs=True,
		n=1,
		**kwargs,
	)
	choice = response.choices[0]
	answer = choice.message.content
	logprobs = list(map(lambda x: x.logprob, choice.logprobs.content))
	tokens = list(
		map(
			lambda x: tokenizer.encode(x.token, allowed_special="all")[0],
			choice.logprobs.content,
		)
	)
	assert len(tokens) == len(logprobs), "tokens and logprobs must have the same length."
	return answer, tokens, logprobs
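
The helper above can also be exercised on its own. A minimal sketch, assuming `OPENAI_API_KEY` is exported (the AsyncOpenAI client reads it by default) and network access is available; `_demo_get_result` is a hypothetical name:

import asyncio

import tiktoken
from openai import AsyncOpenAI


async def _demo_get_result():
	client = AsyncOpenAI()  # picks up OPENAI_API_KEY from the environment
	tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
	answer, tokens, logprobs = await get_result(
		"What is the capital of France?", client, "gpt-3.5-turbo", tokenizer
	)
	assert len(tokens) == len(logprobs)  # one log prob per generated token id
	return answer


# asyncio.run(_demo_get_result())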
async def get_result_o1(
	prompt: str, client: AsyncOpenAI, model: str, tokenizer: Encoding, **kwargs
):
	assert model.startswith("o1"), "This function only supports o1 models."
	# The o1 models only support temperature=1 (the default).
	# See https://platform.openai.com/docs/guides/reasoning for the beta limitations of the o1 models.
	kwargs["temperature"] = 1
	kwargs["top_p"] = 1
	kwargs["presence_penalty"] = 0
	kwargs["frequency_penalty"] = 0
	response = await client.chat.completions.create(
		model=model,
		messages=[
			{"role": "user", "content": prompt},
		],
		logprobs=False,
		n=1,
		**kwargs,
	)
	answer = response.choices[0].message.content
	tokens = tokenizer.encode(answer, allowed_special="all")
	pseudo_log_probs = [0.5] * len(tokens)
	return answer, tokens, pseudo_log_probs
def truncate_by_token(prompt: str, tokenizer: Encoding, max_token_size: int):
	tokens = tokenizer.encode(prompt, allowed_special="all")
	return tokenizer.decode(tokens[:max_token_size])
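
As a quick illustration of the truncation helper (a sketch only; the 7-token margin mirrors the chat-overhead allowance applied in `openai_llm`):

import tiktoken

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
limit = MAX_TOKEN_DICT["gpt-3.5-turbo"] - 7  # model limit minus the chat-overhead margin
long_prompt = "Summarize this passage. " * 10_000
short_prompt = truncate_by_token(long_prompt, tokenizer, limit)
# short_prompt is decoded from at most `limit` tokens of the original prompt.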