import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union
import pandas as pd
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError
from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.utils import fetch_contents
logger = logging.getLogger("AutoRAG")

# This module lives in <package root>/deploy/; the VERSION file sits at the
# package root, one level up.
deploy_dir = pathlib.Path(__file__).parent
root_dir = deploy_dir.parent
VERSION_PATH = os.path.join(root_dir, "VERSION")
class QueryRequest(BaseModel):
    """Request body accepted by the ``/v1/run`` and ``/v1/stream`` endpoints."""

    # The user query to execute through the pipeline.
    query: str
    # Column of the final pipeline DataFrame returned as the answer.
    result_column: Optional[str] = "generated_texts"
class RetrievedPassage(BaseModel):
    """One passage returned by the retrieval stage of the pipeline."""

    # Passage text and its corpus document id.
    content: str
    doc_id: str
    # Provenance fields; None when the corpus metadata does not provide them.
    filepath: Optional[str] = None
    file_page: Optional[int] = None
    start_idx: Optional[int] = None
    end_idx: Optional[int] = None
class RunResponse(BaseModel):
    """Response body produced by ``/v1/run`` (and each ``/v1/stream`` chunk)."""

    # Generated answer; a list when the result column holds multiple texts.
    result: Union[str, List[str]]
    # Passages the answer was generated from.
    retrieved_passage: List[RetrievedPassage]
class VersionResponse(BaseModel):
    """Response body of the ``/version`` endpoint."""

    # Contents of the package's VERSION file.
    version: str
# Placeholder passage attached to streaming token deltas, which carry no
# retrieval information of their own. The optional provenance fields are
# left at their declared None defaults.
empty_retrieved_passage = RetrievedPassage(content="", doc_id="")
class ApiRunner(BaseRunner):
    def __init__(self, config: Dict, project_dir: Optional[str] = None):
        """Build a Quart application around the configured pipeline.

        :param config: Extracted pipeline configuration (see BaseRunner).
        :param project_dir: AutoRAG project directory; must contain
            ``data/corpus.parquet``.
        """
        super().__init__(config, project_dir)
        self.app = Quart(__name__)
        # The corpus is kept in memory to map retrieved doc_ids back to
        # passage contents when building responses.
        corpus_path = os.path.join(project_dir, "data", "corpus.parquet")
        self.corpus_df = pd.read_parquet(corpus_path, engine="pyarrow")
        self.__add_api_route()
    def __add_api_route(self):
        # Register the HTTP endpoints on self.app. Each handler is a closure
        # over `self`, giving it access to the instantiated pipeline modules
        # (self.module_instances / self.module_params).

        @self.app.route("/v1/run", methods=["POST"])
        async def run_query():
            # Validate the JSON body against the QueryRequest schema;
            # malformed input is rejected with HTTP 400.
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            # Build a one-row frame shaped like QA evaluation data so the
            # pipeline modules can run on a single ad-hoc query.
            previous_result = pd.DataFrame(
                {
                    "qid": str(uuid.uuid4()),
                    "query": [data.query],
                    "retrieval_gt": [[]],
                    "generation_gt": [""],
                }
            )  # pseudo qa data for execution

            # Run every pipeline module in order. Columns produced by the
            # current module replace identically named columns from earlier
            # steps, so each step sees the freshest values.
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                new_result = module_instance.pure(
                    previous_result=previous_result, **module_param
                )
                duplicated_columns = previous_result.columns.intersection(
                    new_result.columns
                )
                drop_previous_result = previous_result.drop(columns=duplicated_columns)
                previous_result = pd.concat([drop_previous_result, new_result], axis=1)

            # Simulate processing the query
            # NOTE(review): a result_column the pipeline did not produce raises
            # KeyError (HTTP 500) -- confirm whether a 400 would be preferable.
            generated_text = previous_result[data.result_column].tolist()[0]
            retrieved_passage = self.extract_retrieve_passage(previous_result)
            response = RunResponse(
                result=generated_text, retrieved_passage=retrieved_passage
            )
            return jsonify(response.model_dump()), 200

        @self.app.route("/v1/stream", methods=["POST"])
        async def stream_query():
            # Same request validation as /v1/run.
            try:
                data = await request.get_json()
                data = QueryRequest(**data)
            except ValidationError as e:
                return jsonify(e.errors()), 400

            @stream_with_context
            async def generate():
                # Same pseudo QA frame as in run_query.
                previous_result = pd.DataFrame(
                    {
                        "qid": str(uuid.uuid4()),
                        "query": [data.query],
                        "retrieval_gt": [[]],
                        "generation_gt": [""],
                    }
                )  # pseudo qa data for execution

                for module_instance, module_param in zip(
                    self.module_instances, self.module_params
                ):
                    if not isinstance(module_instance, BaseGenerator):
                        # Non-generator modules run exactly as in /v1/run.
                        new_result = module_instance.pure(
                            previous_result=previous_result, **module_param
                        )
                        duplicated_columns = previous_result.columns.intersection(
                            new_result.columns
                        )
                        drop_previous_result = previous_result.drop(
                            columns=duplicated_columns
                        )
                        previous_result = pd.concat(
                            [drop_previous_result, new_result], axis=1
                        )
                    else:
                        # Generator reached: first emit one chunk carrying the
                        # retrieved passages (empty result), then stream the
                        # generated text as per-delta chunks.
                        retrieved_passages = self.extract_retrieve_passage(
                            previous_result
                        )
                        response = RunResponse(
                            result="", retrieved_passage=retrieved_passages
                        )
                        yield response.model_dump_json().encode("utf-8")
                        # Start streaming of the result
                        assert len(previous_result) == 1
                        prompt: str = previous_result["prompts"].tolist()[0]
                        async for delta in module_instance.astream(
                            prompt=prompt, **module_param
                        ):
                            # Each chunk is a full RunResponse JSON holding
                            # only the newest delta text.
                            response = RunResponse(
                                result=delta,
                                retrieved_passage=[empty_retrieved_passage],
                            )
                            yield response.model_dump_json().encode("utf-8")

            return generate(), 200, {"X-Something": "value"}

        @self.app.route("/version", methods=["GET"])
        def get_version():
            # Sync view (Quart runs it in a worker thread); re-reads the
            # VERSION file on every request.
            with open(VERSION_PATH, "r") as f:
                version = f.read().strip()
            response = VersionResponse(version=version)
            return jsonify(response.model_dump()), 200
[docs]
def run_api_server(self, host: str = "0.0.0.0", port: int = 8000, **kwargs):
"""
Run the pipeline as api server.
You can send POST request to `http://host:port/run` with json body like below:
.. Code:: json
{
"query": "your query",
"result_column": "generated_texts"
}
And it returns json response like below:
.. Code:: json
{
"answer": "your answer"
}
:param host: The host of the api server.
:param port: The port of the api server.
:param kwargs: Other arguments for Flask app.run.
"""
logger.info(f"Run api server at {host}:{port}")
self.app.run(host=host, port=port, **kwargs)