Skip to content

Chat Router

Bases: APIRouter

Class to deploy a LLM router with FastAPI.

Source code in libre_chat/router.py
class ChatRouter(APIRouter):
    """
    Class to deploy a LLM router with FastAPI.
    """

    def __init__(
        self,
        *args: Any,
        llm: Any,
        path: str = "/prompt",
        conf: Optional[ChatConf] = None,
        examples: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Constructor of the LLM API router with the actual calls
        """
        self.path = path
        self.llm = llm
        self.conf = conf if conf else default_conf
        self.title = self.conf.info.title
        self.description = self.conf.info.description
        self.version = self.conf.info.version
        self.examples = examples if examples else self.conf.info.examples
        example_post = {"prompt": self.examples[0]}

        # Instantiate APIRouter
        super().__init__(
            *args,
            responses=api_responses,
            **kwargs,
        )
        # Create a list to store all connected WebSocket clients
        self.connected_clients: List[WebSocket] = []

        @self.get(
            self.path,
            name="Prompt the LLM",
            description=self.description,
            response_model=PromptResponse,
        )
        def get_prompt(request: Request, prompt: str = self.examples[0]) -> JSONResponse:
            """Send a prompt to the chatbot through HTTP GET operation.

            :param request: The HTTP GET request with a .body()
            :param prompt: Prompt to send to the LLM
            """
            return JSONResponse(self.llm.query(prompt))

        @self.post(
            self.path,
            name="Prompt the LLM",
            description=self.description,
            response_description="Prompt response",
            response_model=PromptResponse,
        )
        def post_prompt(
            request: Request,
            prompt: Prompt = Body(..., example=example_post),
        ) -> JSONResponse:
            """Send a prompt to the chatbot through HTTP POST operation.

            :param request: The HTTP POST request with a .body()
            :param prompt: Prompt to send to the LLM.
            """
            return JSONResponse(self.llm.query(prompt.prompt))

        @self.post(
            "/documents",
            description="""Upload documents to be added to the vectorstore, you can provide a zip file that will be automatically unzipped.""",
            response_description="Operation result",
            response_model={},
            tags=["vectorstore"],
        )
        def upload_documents(
            files: List[UploadFile] = File(...),
            admin_pass: Optional[str] = None,
            # current_user: User = Depends(get_current_user),
        ) -> JSONResponse:
            os.makedirs(self.conf.vector.documents_path, exist_ok=True)
            if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
                raise HTTPException(
                    status_code=403,
                    detail="The admin pass key provided was wrong",
                )
            for uploaded in files:
                if uploaded.filename:  # no cov
                    file_path = os.path.join(self.conf.vector.documents_path, uploaded.filename)
                    with open(file_path, "wb") as file:
                        file.write(uploaded.file.read())
                    # Check if the uploaded file is a zip file
                    if uploaded.filename.endswith(".zip"):
                        log.info(f"🤐 Unzipping {file_path}")
                        with zipfile.ZipFile(file_path, "r") as zip_ref:
                            zip_ref.extractall(self.conf.vector.documents_path)
                        os.remove(file_path)
            self.llm.build_vectorstore()
            self.llm.setup_dbqa()
            return JSONResponse(
                {
                    "message": f"Documents uploaded in {self.conf.vector.documents_path}, vectorstore rebuilt."
                }
            )

        @self.get(
            "/documents",
            description="""List documents uploaded to the server.""",
            response_description="List of files",
            response_model={},
            tags=["vectorstore"],
        )
        def list_documents(
            admin_pass: Optional[str] = None,
            # Depends(get_current_user)
        ) -> JSONResponse:
            """List all documents in the documents folder."""
            if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
                raise HTTPException(
                    status_code=403,
                    detail="The admin pass key provided was wrong",
                )
            file_list = os.listdir(self.conf.vector.documents_path)
            return JSONResponse({"count": len(file_list), "files": file_list})

        @self.get(
            "/config",
            name="Get Chat configuration",
            description="""Get the Chat web service configuration.""",
            response_description="Chat configuration",
            response_model=ChatConf,
            tags=["configuration"],
        )
        def get_config(
            admin_pass: Optional[str] = None,
        ) -> JSONResponse:
            """Get the Chat web service configuration."""
            if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
                raise HTTPException(
                    status_code=403,
                    detail="The admin pass key provided was wrong",
                )
            return JSONResponse(self.conf.dict())

        @self.post(
            "/config",
            name="Edit Chat configuration",
            description="""Edit the Chat web service configuration.""",
            response_description="Chat configuration",
            response_model=ChatConf,
            tags=["configuration"],
        )
        def post_config(
            request: Request,
            config: ChatConf = Body(..., example=self.conf),
            admin_pass: Optional[str] = None,
        ) -> JSONResponse:
            """Edit the Chat web service configuration."""
            if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
                raise HTTPException(
                    status_code=403,
                    detail="The admin pass key provided was wrong",
                )
            self.conf = config
            # TODO: save new config to disk, and make sure all workers reload the new config
            return JSONResponse(self.conf.dict())

        @self.websocket("/chat")
        async def websocket_endpoint(websocket: WebSocket) -> None:
            await websocket.accept()
            self.connected_clients.append(websocket)
            log.info(
                f"🔌 New websocket connection: {len(self.connected_clients)} clients are connected"
            )
            memory = ConversationBufferMemory(ai_prefix="AI Assistant")
            try:
                # Loop to receive messages from the WebSocket client
                while True:
                    data = await websocket.receive_json()

                    start_resp = ChatResponse(sender="bot", message="", type="start")
                    await websocket.send_json(start_resp.dict())

                    resp = await self.llm.aquery(
                        data["prompt"],
                        memory=memory,
                        callbacks=[StreamWebsocketCallback(websocket)],
                    )
                    # chat_history.append((question, resp["result"]))
                    # log.warning("RESULTS!")
                    # log.warning(resp["result"])

                    end_resp = ChatResponse(
                        sender="bot",
                        message=resp["result"],
                        type="end",
                        sources=resp["source_documents"] if "source_documents" in resp else None,
                    )
                    await websocket.send_json(end_resp.model_dump())
            except Exception as e:
                log.error(f"WebSocket error: {e}")
            finally:
                self.connected_clients.remove(websocket)

__init__(*args, llm, path='/prompt', conf=None, examples=None, **kwargs)

Constructor of the LLM API router with the actual calls

Source code in libre_chat/router.py
def __init__(
    self,
    *args: Any,
    llm: Any,
    path: str = "/prompt",
    conf: Optional[ChatConf] = None,
    examples: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """
    Constructor of the LLM API router with the actual calls
    """
    self.path = path
    self.llm = llm
    self.conf = conf if conf else default_conf
    self.title = self.conf.info.title
    self.description = self.conf.info.description
    self.version = self.conf.info.version
    self.examples = examples if examples else self.conf.info.examples
    example_post = {"prompt": self.examples[0]}

    # Instantiate APIRouter
    super().__init__(
        *args,
        responses=api_responses,
        **kwargs,
    )
    # Create a list to store all connected WebSocket clients
    self.connected_clients: List[WebSocket] = []

    @self.get(
        self.path,
        name="Prompt the LLM",
        description=self.description,
        response_model=PromptResponse,
    )
    def get_prompt(request: Request, prompt: str = self.examples[0]) -> JSONResponse:
        """Send a prompt to the chatbot through HTTP GET operation.

        :param request: The HTTP GET request with a .body()
        :param prompt: Prompt to send to the LLM
        """
        return JSONResponse(self.llm.query(prompt))

    @self.post(
        self.path,
        name="Prompt the LLM",
        description=self.description,
        response_description="Prompt response",
        response_model=PromptResponse,
    )
    def post_prompt(
        request: Request,
        prompt: Prompt = Body(..., example=example_post),
    ) -> JSONResponse:
        """Send a prompt to the chatbot through HTTP POST operation.

        :param request: The HTTP POST request with a .body()
        :param prompt: Prompt to send to the LLM.
        """
        return JSONResponse(self.llm.query(prompt.prompt))

    @self.post(
        "/documents",
        description="""Upload documents to be added to the vectorstore, you can provide a zip file that will be automatically unzipped.""",
        response_description="Operation result",
        response_model={},
        tags=["vectorstore"],
    )
    def upload_documents(
        files: List[UploadFile] = File(...),
        admin_pass: Optional[str] = None,
        # current_user: User = Depends(get_current_user),
    ) -> JSONResponse:
        os.makedirs(self.conf.vector.documents_path, exist_ok=True)
        if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
            raise HTTPException(
                status_code=403,
                detail="The admin pass key provided was wrong",
            )
        for uploaded in files:
            if uploaded.filename:  # no cov
                file_path = os.path.join(self.conf.vector.documents_path, uploaded.filename)
                with open(file_path, "wb") as file:
                    file.write(uploaded.file.read())
                # Check if the uploaded file is a zip file
                if uploaded.filename.endswith(".zip"):
                    log.info(f"🤐 Unzipping {file_path}")
                    with zipfile.ZipFile(file_path, "r") as zip_ref:
                        zip_ref.extractall(self.conf.vector.documents_path)
                    os.remove(file_path)
        self.llm.build_vectorstore()
        self.llm.setup_dbqa()
        return JSONResponse(
            {
                "message": f"Documents uploaded in {self.conf.vector.documents_path}, vectorstore rebuilt."
            }
        )

    @self.get(
        "/documents",
        description="""List documents uploaded to the server.""",
        response_description="List of files",
        response_model={},
        tags=["vectorstore"],
    )
    def list_documents(
        admin_pass: Optional[str] = None,
        # Depends(get_current_user)
    ) -> JSONResponse:
        """List all documents in the documents folder."""
        if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
            raise HTTPException(
                status_code=403,
                detail="The admin pass key provided was wrong",
            )
        file_list = os.listdir(self.conf.vector.documents_path)
        return JSONResponse({"count": len(file_list), "files": file_list})

    @self.get(
        "/config",
        name="Get Chat configuration",
        description="""Get the Chat web service configuration.""",
        response_description="Chat configuration",
        response_model=ChatConf,
        tags=["configuration"],
    )
    def get_config(
        admin_pass: Optional[str] = None,
    ) -> JSONResponse:
        """Get the Chat web service configuration."""
        if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
            raise HTTPException(
                status_code=403,
                detail="The admin pass key provided was wrong",
            )
        return JSONResponse(self.conf.dict())

    @self.post(
        "/config",
        name="Edit Chat configuration",
        description="""Edit the Chat web service configuration.""",
        response_description="Chat configuration",
        response_model=ChatConf,
        tags=["configuration"],
    )
    def post_config(
        request: Request,
        config: ChatConf = Body(..., example=self.conf),
        admin_pass: Optional[str] = None,
    ) -> JSONResponse:
        """Edit the Chat web service configuration."""
        if self.conf.auth.admin_pass and admin_pass != self.conf.auth.admin_pass:
            raise HTTPException(
                status_code=403,
                detail="The admin pass key provided was wrong",
            )
        self.conf = config
        # TODO: save new config to disk, and make sure all workers reload the new config
        return JSONResponse(self.conf.dict())

    @self.websocket("/chat")
    async def websocket_endpoint(websocket: WebSocket) -> None:
        await websocket.accept()
        self.connected_clients.append(websocket)
        log.info(
            f"🔌 New websocket connection: {len(self.connected_clients)} clients are connected"
        )
        memory = ConversationBufferMemory(ai_prefix="AI Assistant")
        try:
            # Loop to receive messages from the WebSocket client
            while True:
                data = await websocket.receive_json()

                start_resp = ChatResponse(sender="bot", message="", type="start")
                await websocket.send_json(start_resp.dict())

                resp = await self.llm.aquery(
                    data["prompt"],
                    memory=memory,
                    callbacks=[StreamWebsocketCallback(websocket)],
                )
                # chat_history.append((question, resp["result"]))
                # log.warning("RESULTS!")
                # log.warning(resp["result"])

                end_resp = ChatResponse(
                    sender="bot",
                    message=resp["result"],
                    type="end",
                    sources=resp["source_documents"] if "source_documents" in resp else None,
                )
                await websocket.send_json(end_resp.model_dump())
        except Exception as e:
            log.error(f"WebSocket error: {e}")
        finally:
            self.connected_clients.remove(websocket)