import logging
import os
import pathlib
import uuid
from typing import Dict, Optional, List, Union, Literal
import pandas as pd
from pyngrok import ngrok
from quart import Quart, request, jsonify
from quart.helpers import stream_with_context
from pydantic import BaseModel, ValidationError
from autorag.deploy.base import BaseRunner
from autorag.nodes.generator.base import BaseGenerator
from autorag.nodes.promptmaker.base import BasePromptMaker
from autorag.utils import fetch_contents
logger = logging.getLogger("AutoRAG")

# Directory holding this module; root_dir is the directory one level above it.
deploy_dir = pathlib.Path(__file__).parent
root_dir = deploy_dir.parent
# String path of the VERSION file under root_dir (read by the /version route).
VERSION_PATH = os.path.join(root_dir, "VERSION")
[docs]
class QueryRequest(BaseModel):
    """Request body accepted by the /v1/run and /v1/stream endpoints."""

    # The user query fed into the pipeline.
    query: str
    # Pipeline output column whose first value /v1/run returns as ``result``.
    result_column: Union[str, None] = "generated_texts"
[docs]
class RetrievedPassage(BaseModel):
    """A passage returned by the pipeline, plus optional extra details."""

    content: str
    doc_id: str
    # Optional details; each one stays None when it is not available.
    filepath: Union[str, None] = None
    file_page: Union[int, None] = None
    start_idx: Union[int, None] = None
    end_idx: Union[int, None] = None
[docs]
class RunResponse(BaseModel):
    """Response body of the /v1/run endpoint."""

    # First value of the requested result column: a string or a list of strings.
    result: Union[str, List[str]]
    # Passages retrieved while the pipeline ran.
    retrieved_passage: List[RetrievedPassage]
[docs]
class Passage(BaseModel):
    """A single retrieval hit as returned by the /v1/retrieve endpoint."""

    doc_id: str
    content: str
    # Retrieval score the pipeline reported for this passage.
    score: float
[docs]
class RetrievalResponse(BaseModel):
    """Response body of the /v1/retrieve endpoint."""

    # Retrieved passages, in the order the pipeline produced them.
    passages: List[Passage]
[docs]
class StreamResponse(BaseModel):
    """
    One chunk of the /v1/stream response.

    When ``type`` is "generated_text", only ``generated_text`` is set; the other fields are None.
    When ``type`` is "retrieved_passage", only ``retrieved_passage`` and ``passage_index`` are set;
    ``generated_text`` is None.
    """

    type: Literal["generated_text", "retrieved_passage"]
    # The explicit ``= None`` defaults make these fields genuinely optional:
    # in pydantic v2 an Optional[...] annotation alone still marks a field as required.
    # Set only for "generated_text" chunks.
    generated_text: Optional[str] = None
    # Set only for "retrieved_passage" chunks, together with passage_index.
    retrieved_passage: Optional[RetrievedPassage] = None
    # Position of ``retrieved_passage`` within the list of retrieved passages.
    passage_index: Optional[int] = None
[docs]
class VersionResponse(BaseModel):
    """Response body of the /version endpoint."""

    # Contents of the VERSION file, stripped of surrounding whitespace.
    version: str
[docs]
class ApiRunner(BaseRunner):
def __init__(self, config: Dict, project_dir: Optional[str] = None):
    """
    Build the Quart app for a pipeline and register its HTTP routes.

    :param config: Pipeline config, forwarded unchanged to ``BaseRunner``.
    :param project_dir: AutoRAG project directory. ``corpus.parquet`` is read
        from its ``data`` sub-directory. Falls back to the current working
        directory when omitted.
    """
    super().__init__(config, project_dir)
    # os.path.join(None, ...) raises TypeError, so an omitted project_dir
    # needs a concrete fallback.
    # NOTE(review): assumed to agree with BaseRunner's own default, which is
    # not visible here -- confirm.
    if project_dir is None:
        project_dir = os.getcwd()
    self.app = Quart(__name__)
    data_dir = os.path.join(project_dir, "data")
    self.corpus_df = pd.read_parquet(
        os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
    )
    self.__add_api_route()
def __add_api_route(self):
    """
    Register the HTTP routes of the deployed pipeline on ``self.app``.

    * ``POST /v1/run``: run every pipeline module and return the first value of
      ``result_column`` together with the retrieved passages.
    * ``POST /v1/retrieve``: run the pipeline while skipping prompt-maker and
      generator modules, and return the retrieved passages with their scores.
    * ``POST /v1/stream``: run the modules before the generator, then stream the
      retrieved passages followed by the generator's text deltas.
    * ``GET /version``: return the contents of the VERSION file.

    Request bodies that are valid JSON but not an object, and unknown result
    columns, are answered with HTTP 400 instead of an unhandled server error.
    """

    @self.app.route("/v1/run", methods=["POST"])
    async def run_query():
        try:
            data = await request.get_json()
            # Unlike QueryRequest(**data), model_validate also rejects JSON
            # bodies that are not objects (e.g. `null` or `[]`) with a
            # ValidationError.
            data = QueryRequest.model_validate(data)
        except ValidationError as e:
            return jsonify(e.errors()), 400

        previous_result = self._build_pseudo_qa(data.query)
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            previous_result = self._apply_module(
                previous_result, module_instance, module_param
            )

        # An unknown (or null) result_column is a client error, not a KeyError.
        if data.result_column not in previous_result.columns:
            return jsonify(
                {
                    "error": f"Invalid request. The pipeline did not produce a '{data.result_column}' column."
                }
            ), 400
        generated_text = previous_result[data.result_column].tolist()[0]
        retrieved_passage = self.extract_retrieve_passage(previous_result)
        response = RunResponse(
            result=generated_text, retrieved_passage=retrieved_passage
        )
        return jsonify(response.model_dump()), 200

    @self.app.route("/v1/retrieve", methods=["POST"])
    async def run_retrieve_only():
        data = await request.get_json()
        # The body may be valid JSON that is not an object (e.g. `null`).
        query = data.get("query", None) if isinstance(data, dict) else None
        if query is None:
            return jsonify(
                {
                    "error": "Invalid request. You need to include 'query' in the request body."
                }
            ), 400

        previous_result = self._build_pseudo_qa(query)
        for module_instance, module_param in zip(
            self.module_instances, self.module_params
        ):
            # Retrieval only: skip modules that build prompts or generate text.
            if isinstance(module_instance, (BasePromptMaker, BaseGenerator)):
                continue
            previous_result = self._apply_module(
                previous_result, module_instance, module_param
            )

        retrieved_contents = previous_result["retrieved_contents"].tolist()[0]
        retrieved_ids = previous_result["retrieved_ids"].tolist()[0]
        retrieve_scores = previous_result["retrieve_scores"].tolist()[0]
        retrieval_response = RetrievalResponse(
            passages=[
                Passage(doc_id=doc_id, content=content, score=score)
                for doc_id, content, score in zip(
                    retrieved_ids, retrieved_contents, retrieve_scores
                )
            ]
        )
        return jsonify(retrieval_response.model_dump()), 200

    @self.app.route("/v1/stream", methods=["POST"])
    async def stream_query():
        try:
            data = await request.get_json()
            data = QueryRequest.model_validate(data)
        except ValidationError as e:
            return jsonify(e.errors()), 400

        @stream_with_context
        async def generate():
            previous_result = self._build_pseudo_qa(data.query)
            for module_instance, module_param in zip(
                self.module_instances, self.module_params
            ):
                if not isinstance(module_instance, BaseGenerator):
                    previous_result = self._apply_module(
                        previous_result, module_instance, module_param
                    )
                    continue

                # Reached the generator: first emit every retrieved passage...
                retrieved_passages = self.extract_retrieve_passage(previous_result)
                for i, retrieved_passage in enumerate(retrieved_passages):
                    yield (
                        StreamResponse(
                            type="retrieved_passage",
                            generated_text=None,
                            retrieved_passage=retrieved_passage,
                            passage_index=i,
                        )
                        .model_dump_json()
                        .encode("utf-8")
                    )

                # ...then stream the generator output for the single prompt.
                assert len(previous_result) == 1
                prompt: str = previous_result["prompts"].tolist()[0]
                async for delta in module_instance.astream(
                    prompt=prompt, **module_param
                ):
                    response = StreamResponse(
                        type="generated_text",
                        generated_text=delta,
                        retrieved_passage=None,
                        passage_index=None,
                    )
                    yield response.model_dump_json().encode("utf-8")

        return generate(), 200

    @self.app.route("/version", methods=["GET"])
    def get_version():
        with open(VERSION_PATH, "r", encoding="utf-8") as f:
            version = f.read().strip()
        response = VersionResponse(version=version)
        return jsonify(response.model_dump()), 200

@staticmethod
def _build_pseudo_qa(query: str) -> pd.DataFrame:
    """Build the one-row pseudo QA frame that pipeline modules run on."""
    return pd.DataFrame(
        {
            "qid": str(uuid.uuid4()),
            "query": [query],
            "retrieval_gt": [[]],
            "generation_gt": [""],
        }
    )

@staticmethod
def _apply_module(
    previous_result: pd.DataFrame, module_instance, module_param
) -> pd.DataFrame:
    """Run one module and merge its output columns into ``previous_result``.

    Columns produced by the module replace same-named columns of
    ``previous_result``; the new columns are appended on the right.
    """
    new_result = module_instance.pure(
        previous_result=previous_result, **module_param
    )
    duplicated_columns = previous_result.columns.intersection(new_result.columns)
    drop_previous_result = previous_result.drop(columns=duplicated_columns)
    return pd.concat([drop_previous_result, new_result], axis=1)
[docs]
def run_api_server(
    self, host: str = "0.0.0.0", port: int = 8000, remote: bool = True, **kwargs
):
    """
    Run the pipeline as an api server.
    Here is api endpoint documentation => https://docs.auto-rag.com/deploy/api_endpoint.html
    :param host: The host of the api server.
    :param port: The port of the api server.
    :param remote: Whether to expose the api server to the public internet using ngrok.
        The ngrok tunnel is disconnected again once ``app.run`` returns or raises.
    :param kwargs: Other arguments for Quart ``app.run``.
    """
    logger.info("Run api server at %s:%s", host, port)
    public_url = None
    if remote:
        http_tunnel = ngrok.connect(str(port), "http")
        public_url = http_tunnel.public_url
        logger.info("Public API URL: %s", public_url)
    try:
        self.app.run(host=host, port=port, **kwargs)
    finally:
        # Do not leave a public tunnel pointing at a server that has stopped.
        if public_url is not None:
            ngrok.disconnect(public_url)