# Source code for autorag.deploy.api

import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union, Literal

import pandas as pd
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError

from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.nodes.promptmaker.base import BasePromptMaker
from autorag.utils.util import fetch_contents, to_list

# Logger registered under the name "AutoRAG".
logger = logging.getLogger("AutoRAG")

# Directory that holds this module, and the directory one level above it.
deploy_dir = pathlib.Path(__file__).parents[0]
root_dir = deploy_dir.parent

# String path of the VERSION file located directly under ``root_dir``.
VERSION_PATH = os.fspath(root_dir / "VERSION")


class QueryRequest(BaseModel):
    """
    JSON body accepted by the ``/v1/run`` and ``/v1/stream`` endpoints.
    """

    # The user query that is fed into the pipeline.
    query: str
    # Name of the pipeline result column whose first value ``/v1/run``
    # returns as ``result``.
    result_column: Optional[str] = "generated_texts"
class RetrievedPassage(BaseModel):
    """
    One passage returned by the retrieval part of the pipeline.

    Instances are built by ``ApiRunner.extract_retrieve_passage``.
    """

    # Passage text fetched from the corpus.
    content: str
    # Corpus id of the passage (an entry of the ``retrieved_ids`` column).
    doc_id: str
    # Matching entry of the ``retrieve_scores`` column.
    score: float
    # Value of the corpus ``path`` column, when the corpus has one.
    filepath: Optional[str] = None
    # ``page`` entry of the passage metadata, when present.
    file_page: Optional[int] = None
    # First and second element of the corpus ``start_end_idx`` value, when present.
    start_idx: Optional[int] = None
    end_idx: Optional[int] = None
class RunResponse(BaseModel):
    """
    Response body of the ``/v1/run`` endpoint.
    """

    # First value of the requested result column.
    result: Union[str, List[str]]
    # Passages extracted from the final pipeline result.
    retrieved_passage: List[RetrievedPassage]
class RetrievalResponse(BaseModel):
    """
    Response body of the ``/v1/retrieve`` endpoint.
    """

    # Passages extracted from the pipeline result.
    passages: List[RetrievedPassage]
class StreamResponse(BaseModel):
    """
    One chunk emitted by the ``/v1/stream`` endpoint.

    When the type is generated_text, only generated_text is returned. The other fields are None.
    When the type is retrieved_passage, only retrieved_passage and passage_index are returned.
    The other fields are None.

    The optional fields default to None, so a chunk can be built by passing
    only the fields that belong to its ``type``. Passing every field
    explicitly keeps working as before.
    """

    # Discriminator that tells which of the optional fields is populated.
    type: Literal["generated_text", "retrieved_passage"]
    # Text delta produced by the generator (type == "generated_text").
    generated_text: Optional[str] = None
    # Passage payload (type == "retrieved_passage").
    retrieved_passage: Optional[RetrievedPassage] = None
    # Zero-based position of the passage in the retrieved list
    # (type == "retrieved_passage").
    passage_index: Optional[int] = None
class VersionResponse(BaseModel):
    """
    Response body of the ``/version`` endpoint.
    """

    # Stripped content of the VERSION file.
    version: str
class ApiRunner(BaseRunner):
    """
    Serve a configured pipeline over HTTP with a Quart application.

    Routes registered on ``self.app``:

    - ``POST /v1/run``: run every module and return the requested result
      column together with the retrieved passages.
    - ``POST /v1/retrieve``: run every module except prompt makers and
      generators and return the retrieved passages only.
    - ``POST /v1/stream``: stream the retrieved passages, then the generated
      text deltas, as UTF-8 encoded JSON chunks.
    - ``GET /version``: return the content of the VERSION file.
    """

    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """
        Create the Quart app, load the corpus and register the routes.

        :param config: Pipeline configuration, forwarded to ``BaseRunner``.
        :param project_dir: Directory that contains ``data/corpus.parquet``.
            It is forwarded unchanged to ``BaseRunner``. When it is None, the
            corpus is looked up relative to the current working directory.
        """
        super().__init__(config, project_dir)
        self.app = Quart(__name__)

        # os.path.join(None, ...) raises TypeError, so resolve the default
        # explicitly before building the corpus path.
        corpus_root = os.getcwd() if project_dir is None else project_dir
        self.corpus_df = pd.read_parquet(
            os.path.join(corpus_root, "data", "corpus.parquet"), engine="pyarrow"
        )
        self.__add_api_route()

    @staticmethod
    def _build_pseudo_qa(query: str) -> pd.DataFrame:
        """Build the single-row pseudo QA frame used as the first module input."""
        return pd.DataFrame(
            {
                "qid": [str(uuid.uuid4())],
                "query": [query],
                "retrieval_gt": [[]],
                "generation_gt": [""],
            }
        )

    @staticmethod
    def _apply_module(
        previous_result: pd.DataFrame, module_instance, module_param: Dict
    ) -> pd.DataFrame:
        """Run one module and merge its output, letting new columns replace old ones."""
        new_result = module_instance.pure(
            previous_result=previous_result, **module_param
        )
        duplicated_columns = previous_result.columns.intersection(new_result.columns)
        return pd.concat(
            [previous_result.drop(columns=duplicated_columns), new_result], axis=1
        )

    def __add_api_route(self):
        """Register every HTTP route on ``self.app``."""

        @self.app.route("/v1/run", methods=["POST"])
        async def run_query():
            try:
                # model_validate raises ValidationError for a missing or
                # non-object body as well, so those cases become a 400
                # instead of an unhandled TypeError.
                data = QueryRequest.model_validate(await request.get_json())
            except ValidationError as e:
                return jsonify(e.errors()), 400

            previous_result = self._build_pseudo_qa(data.query)
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                previous_result = self._apply_module(
                    previous_result, module_instance, module_param
                )

            # result_column comes from the client; an unknown name is a
            # request error, not a server error.
            if data.result_column not in previous_result.columns:
                return jsonify(
                    {
                        "error": f"Invalid request. Unknown result_column '{data.result_column}'."
                    }
                ), 400

            generated_text = previous_result[data.result_column].tolist()[0]
            retrieved_passage = self.extract_retrieve_passage(previous_result)
            response = RunResponse(
                result=generated_text, retrieved_passage=retrieved_passage
            )
            return jsonify(response.model_dump()), 200

        @self.app.route("/v1/retrieve", methods=["POST"])
        async def run_retrieve_only():
            data = await request.get_json()
            # The body may be None or a non-object JSON value.
            query = data.get("query", None) if isinstance(data, dict) else None
            if query is None:
                return jsonify(
                    {
                        "error": "Invalid request. You need to include 'query' in the request body."
                    }
                ), 400

            previous_result = self._build_pseudo_qa(query)
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                # Prompt makers and generators are not needed for retrieval.
                if isinstance(module_instance, (BasePromptMaker, BaseGenerator)):
                    continue
                previous_result = self._apply_module(
                    previous_result, module_instance, module_param
                )

            retrieved_passages = self.extract_retrieve_passage(previous_result)
            retrieval_response = RetrievalResponse(passages=retrieved_passages)
            return jsonify(retrieval_response.model_dump()), 200

        @self.app.route("/v1/stream", methods=["POST"])
        async def stream_query():
            try:
                data = QueryRequest.model_validate(await request.get_json())
            except ValidationError as e:
                return jsonify(e.errors()), 400

            @stream_with_context
            async def generate():
                previous_result = self._build_pseudo_qa(data.query)
                for module_instance, module_param in zip(
                    self.module_instances, self.module_params
                ):
                    if not isinstance(module_instance, BaseGenerator):
                        previous_result = self._apply_module(
                            previous_result, module_instance, module_param
                        )
                        continue

                    # Emit the retrieved passages before any generated text.
                    retrieved_passages = self.extract_retrieve_passage(
                        previous_result
                    )
                    for i, retrieved_passage in enumerate(retrieved_passages):
                        yield (
                            StreamResponse(
                                type="retrieved_passage",
                                generated_text=None,
                                retrieved_passage=retrieved_passage,
                                passage_index=i,
                            )
                            .model_dump_json()
                            .encode("utf-8")
                        )

                    # Start streaming of the result
                    assert len(previous_result) == 1
                    prompt: str = previous_result["prompts"].tolist()[0]
                    async for delta in module_instance.astream(
                        prompt=prompt, **module_param
                    ):
                        response = StreamResponse(
                            type="generated_text",
                            generated_text=delta,
                            retrieved_passage=None,
                            passage_index=None,
                        )
                        yield response.model_dump_json().encode("utf-8")

            # NOTE(review): the "X-Something" header looks like a placeholder;
            # it is kept so the response stays unchanged for existing clients.
            return generate(), 200, {"X-Something": "value"}

        @self.app.route("/version", methods=["GET"])
        def get_version():
            with open(VERSION_PATH, "r", encoding="utf-8") as f:
                version = f.read().strip()
            response = VersionResponse(version=version)
            return jsonify(response.model_dump()), 200

    def run_api_server(
        self, host: str = "0.0.0.0", port: int = 8000, remote: bool = True, **kwargs
    ):
        """
        Run the pipeline as an api server.
        Here is api endpoint documentation => https://docs.auto-rag.com/deploy/api_endpoint.html

        :param host: The host of the api server.
        :param port: The port of the api server.
        :param remote: Whether to expose the api server to the public internet using ngrok.
        :param kwargs: Other arguments for Quart app.run.
        """
        logger.info(f"Run api server at {host}:{port}")
        if remote:
            # Imported lazily so pyngrok is only needed when a tunnel is requested.
            from pyngrok import ngrok

            http_tunnel = ngrok.connect(str(port), "http")
            public_url = http_tunnel.public_url
            logger.info(f"Public API URL: {public_url}")
        self.app.run(host=host, port=port, **kwargs)

    def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
        """
        Build the retrieved passages from the first row of a pipeline result.

        :param df: Pipeline result that has the ``retrieved_ids`` and
            ``retrieve_scores`` columns. Only the first row is read.
        :return: One ``RetrievedPassage`` per retrieved id. ``filepath``,
            ``file_page``, ``start_idx`` and ``end_idx`` are None when the
            corpus does not provide the matching value.
        """
        retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
        scores = df["retrieve_scores"].tolist()[0]
        contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]

        def fetch_optional(column_name: str):
            # Optional corpus columns yield one None per passage when absent.
            if column_name not in self.corpus_df.columns:
                return [None] * len(retrieved_ids)
            return fetch_contents(
                self.corpus_df, [retrieved_ids], column_name=column_name
            )[0]

        paths = fetch_optional("path")
        metadatas = fetch_optional("metadata")
        start_end_indices = to_list(fetch_optional("start_end_idx"))

        return [
            RetrievedPassage(
                content=content,
                doc_id=doc_id,
                score=score,
                filepath=path,
                file_page=metadata.get("page", None) if metadata is not None else None,
                start_idx=start_end_idx[0] if start_end_idx else None,
                end_idx=start_end_idx[1] if start_end_idx else None,
            )
            for content, doc_id, score, path, metadata, start_end_idx in zip(
                contents, retrieved_ids, scores, paths, metadatas, start_end_indices
            )
        ]