Source code for autorag.deploy.gradio

import logging
import uuid

import pandas as pd

from autorag.deploy.base import BaseRunner

import gradio as gr


logger = logging.getLogger("AutoRAG")


[docs] class GradioRunner(BaseRunner):
[docs] def run_web( self, server_name: str = "0.0.0.0", server_port: int = 7680, share: bool = False, **kwargs, ): """ Run web interface to interact pipeline. You can access the web interface at `http://server_name:server_port` in your browser :param server_name: The host of the web. Default is 0.0.0.0. :param server_port: The port of the web. Default is 7680. :param share: Whether to create a publicly shareable link. Default is False. :param kwargs: Other arguments for gr.ChatInterface.launch. """ logger.info(f"Run web interface at http://{server_name}:{server_port}") def get_response(message, _): return self.run(message) gr.ChatInterface( get_response, title="📚 AutoRAG", retry_btn=None, undo_btn=None ).launch( server_name=server_name, server_port=server_port, share=share, **kwargs )
[docs] def run(self, query: str, result_column: str = "generated_texts"): """ Run the pipeline with query. The loaded pipeline must start with a single query, so the first module of the pipeline must be `query_expansion` or `retrieval` module. :param query: The query of the user. :param result_column: The result column name for the answer. Default is `generated_texts`, which is the output of the `generation` module. :return: The result of the pipeline. """ previous_result = pd.DataFrame( { "qid": str(uuid.uuid4()), "query": [query], "retrieval_gt": [[]], "generation_gt": [""], } ) # pseudo qa data for execution for module_instance, module_param in zip( self.module_instances, self.module_params ): new_result = module_instance.pure( previous_result=previous_result, **module_param ) duplicated_columns = previous_result.columns.intersection( new_result.columns ) drop_previous_result = previous_result.drop(columns=duplicated_columns) previous_result = pd.concat([drop_previous_result, new_result], axis=1) return previous_result[result_column].tolist()[0]