Start the FlexAI endpoints
Create the FlexAI secret that contains your HF token in order to access the inference models:
# Enter your HF token value when prompted
flexai secret create hf-token
LLM_INFERENCE_NAME=qwen-llm
export LLM_MODEL_NAME=Qwen/Qwen3-30B-A3B-Instruct-2507
# Quote expansions so values containing spaces or glob characters survive word-splitting.
flexai inference serve "$LLM_INFERENCE_NAME" --hf-token-secret hf-token -- \
  --model="$LLM_MODEL_NAME" \
  --enable-auto-tool-choice \
  --tool-call-parser hermes \
  --max-model-len 16384
# store the returned information
# (the placeholder must be quoted: an unquoted <...> is parsed as a shell redirection)
export LLM_API_KEY="<store the given API key>"
export LLM_URL="$(flexai inference inspect "$LLM_INFERENCE_NAME" -j | jq -r '.config.endpointUrl')"
EMBED_INFERENCE_NAME=e5-embed
export EMBEDDINGS_MODEL_NAME=intfloat/multilingual-e5-large
# Quote expansions so values containing spaces or glob characters survive word-splitting.
flexai inference serve "$EMBED_INFERENCE_NAME" --runtime vllm-nvidia-0.9.1 --hf-token-secret hf-token -- \
  --model="$EMBEDDINGS_MODEL_NAME" \
  --task=embed \
  --trust-remote-code \
  --dtype=float32
# store the returned information
# (the placeholder must be quoted: an unquoted <...> is parsed as a shell redirection)
export EMBEDDINGS_API_KEY="<store the given API key>"
export EMBEDDINGS_URL="$(flexai inference inspect "$EMBED_INFERENCE_NAME" -j | jq -r '.config.endpointUrl')"
LangSmith (Optional)
# Optional: enable LangSmith tracing of the LangGraph pipeline.
export LANGSMITH_TRACING="true"
export LANGSMITH_API_KEY="..."  # your LangSmith API key
export LANGSMITH_PROJECT=rag-with-flexai
Setup
The code of this experiment is located at `code/rag`, and the following commands should be run from this location.
Using Docker
Run the Docker container
# Quote each value so API keys or URLs containing special characters survive word-splitting.
docker run -p 7860:7860 \
  -e LLM_MODEL_NAME="$LLM_MODEL_NAME" \
  -e LLM_API_KEY="$LLM_API_KEY" \
  -e LLM_URL="$LLM_URL" \
  -e EMBEDDINGS_MODEL_NAME="$EMBEDDINGS_MODEL_NAME" \
  -e EMBEDDINGS_API_KEY="$EMBEDDINGS_API_KEY" \
  -e EMBEDDINGS_URL="$EMBEDDINGS_URL" \
  -e LANGSMITH_TRACING="$LANGSMITH_TRACING" \
  -e LANGSMITH_API_KEY="$LANGSMITH_API_KEY" \
  -e LANGSMITH_PROJECT="$LANGSMITH_PROJECT" \
  rag-application
Local Setup
Usage
Once the application is running, you can access the Gradio interface in your web browser at http://localhost:7860. You can upload documents and ask questions based on the content of those documents.
For example, you can upload the documents located at code/rag/data and ask questions such as:
- What is this demo about?
- For which workflows are LLM agents useful?
- Where can I find bioluminescent fungi?
Code
requirements.txt
beautifulsoup4>=4.13.4
docx2txt>=0.9
gradio==5.49.1
langchain-community==0.3.30
langchain-openai==0.3.35
langchain-text-splitters==0.3.11
langgraph==0.6.8
lxml>=5.4.0
pypdf>=5.4.0
transformers>=4.43.3
run_rag.py
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT
import logging
import os
from uuid import uuid4
import gradio as gr
from src.rag_pipeline import RagPipeline
# Resolve the logo relative to this script so the launch directory doesn't matter.
current_dir = os.path.dirname(os.path.abspath(__file__))
logo_path = os.path.join(current_dir, "logo.png")
# Single shared pipeline instance; tool-calling disabled so retrieval always runs.
rag_pipeline = RagPipeline(chunk_size=500, chunk_overlap=50, use_tools=False)
def get_endpoint_config():
    """Return the current endpoint settings as a list ordered for the UI fields."""
    cfg = rag_pipeline.get_endpoint_config()
    field_order = (
        "llm_model_name",
        "llm_api_key",
        "llm_url",
        "embeddings_model_name",
        "embeddings_api_key",
        "embeddings_url",
    )
    return [cfg.get(key, "") for key in field_order]
def set_endpoint_config(
    llm_model_name,
    llm_api_key,
    llm_url,
    embeddings_model_name,
    embeddings_api_key,
    embeddings_url,
):
    """Apply new endpoint settings to the pipeline and refresh the UI.

    Returns the (re-read) config values for the six textboxes, an emptied
    document list, and an update for the error banner (shown only on failure).
    """
    new_config = {
        "llm_model_name": llm_model_name,
        "llm_api_key": llm_api_key,
        "llm_url": llm_url,
        "embeddings_model_name": embeddings_model_name,
        "embeddings_api_key": embeddings_api_key,
        "embeddings_url": embeddings_url,
    }
    error_msg = ""
    error_visible = False
    try:
        rag_pipeline.set_endpoint_config(new_config)
    except Exception as e:
        error_msg = f"<span style='color:red; font-weight:bold;'>Error: {e}</span>"
        error_visible = True
    # Re-read the effective config and reset the document list (the vector
    # store is rebuilt whenever the embeddings endpoint may have changed).
    return (
        *get_endpoint_config(),
        clear_document_list(),
        gr.update(value=error_msg, visible=error_visible),
    )
def toggle_api_key_visibility(visible, value):
    """Switch an API-key textbox between clear-text and masked display.

    Args:
        visible: Whether the key should be shown in clear text.
        value: Current textbox value, preserved across the toggle.

    Returns:
        A replacement gr.Textbox with the requested input type.
    """
    # Renamed from `type` to avoid shadowing the builtin.
    field_type = "text" if visible else "password"
    return gr.Textbox(label="API Key", type=field_type, value=value)
def clear_history():
    """Start a fresh conversation: mint a new thread id and empty the chatbot."""
    thread_id = uuid4()
    logging.info(f"New thread_id: {thread_id}")
    # None clears the gr.Chatbot component.
    return str(thread_id), None
def clear_document_list():
    """Wipe the vector store and return an empty list for the doc-list state."""
    rag_pipeline.clear_vector_store()
    return []
def add_message(history, message, thread_id):
    """Append the user message, run the RAG query, and append the bot reply.

    Args:
        history: Chat history as a list of role/content dicts; mutated in place.
        message: The user's input text; ignored when empty.
        thread_id: Conversation id forwarded to the pipeline's checkpointer.

    Returns:
        The updated history and an empty string to clear the input box.
    """
    if not message:
        return history, ""
    history.append({"role": "user", "content": message})
    try:
        result = rag_pipeline.query(message, thread_id=thread_id)
        reply = result["messages"][-1].content
    except Exception as e:
        # Surface pipeline failures inline instead of crashing the UI.
        reply = f"<span style='color:red; font-weight:bold;'>Error: {e}</span>"
    history.append({"role": "assistant", "content": reply})
    return history, ""
def on_files_uploaded(history, files):
    """Index newly uploaded files; report per-file failures in the chat.

    Args:
        history: Chat history as a list of role/content dicts; mutated in place.
        files: Paths of the uploaded files.

    Returns:
        The updated chat history (one combined error message appended when
        any file failed to index).
    """
    error_msgs = []
    for f in files:
        try:
            rag_pipeline.add_documents([f])
        except Exception as e:
            # Report only the basename so temp-dir paths don't leak into the
            # chat. (Bug fix: the message previously hard-coded '(unknown)'
            # instead of using the computed filename.)
            filename = os.path.basename(f)
            error_msgs.append(f"File '{filename}': {e}")
    if error_msgs:
        error_msg = "<br>".join(
            f"<span style='color:red; font-weight:bold;'>Error: {msg}</span>"
            for msg in error_msgs
        )
        history.append({"role": "assistant", "content": error_msg})
    return history
def on_files_deleted():
    """Drop every indexed document when files are removed from the uploader."""
    rag_pipeline.clear_vector_store()
# Custom CSS: light/dark-aware backgrounds, hidden Gradio footer, and
# per-element overrides keyed on the elem_id values used below.
css = """
:root {
color-scheme: light dark;
}
footer { visibility: hidden; }
#chatbot {
background: light-dark(#f7f7fa, #000) !important;
}
#main-card {
background: #fff;
}
.file-preview {
min-height: 240px;
}
.user {
background: light-dark(#fff, #000);
border-color: light-dark(#fff, #000);
}
.bot {
background: light-dark(#f7f7fa, #18181b);
border-color: light-dark(#f7f7fa, #18181b);
box-shadow: 0px 0px 0px 0px !important;
}
#chat-area {
background: light-dark(#f7f7fa, #000);
}
#chat-input {
background: #fff;
}
#send-btn {
background: #fff;
}
#filetable { button { color: black; background-color: white; } }
.tab-content {
background: #fff;
}
"""
# Layout: left column = chat, right column = tabs for document upload and
# endpoint configuration. Event wiring is at the bottom of the block.
with gr.Blocks(
    title="FlexBot: Ask me anything!",
    fill_height=True,
    css=css,
    theme=gr.themes.Default(primary_hue=gr.themes.colors.blue),
) as demo:
    # Per-session state: conversation thread id and the list of indexed docs.
    uuid_state = gr.State(value=str(uuid4()))
    doc_list = gr.State(value=[])
    # Header with title and logo (served via Gradio's file route).
    title = f"""
<div style='display: flex; align-items: center; justify-content: space-between; margin-bottom: 16px;'>
<div>
<h1 style='margin-bottom: 0;'>Search RAG</h1>
<h2 style='margin: 0; font-weight: 400;'>FlexBot can answer questions based on the provided documents.</h2>
</div>
<img src='/gradio_api/file={logo_path}' alt='Logo' style='height: 80px; margin-left: 20px;'>
</div>
"""
    gr.HTML(title)
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group(elem_id="main-card"):
                with gr.Group(elem_id="chat-area"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            chatbot = gr.Chatbot(
                                height="70dvh",
                                show_label=False,
                                elem_id="chatbot",
                                type="messages",
                                scale=1,
                            )
                            chat_input = gr.Textbox(
                                interactive=True,
                                placeholder="Ask me anything!",
                                show_label=False,
                                submit_btn=True,
                            )
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Documents"):
                    with gr.Group(elem_id="tab-content"):
                        # Uploader restricted to the file types the reader
                        # module knows how to parse.
                        files_new = gr.File(
                            file_count="multiple",
                            interactive=True,
                            scale=1,
                            file_types=[
                                ".txt",
                                ".pdf",
                                ".csv",
                                ".doc",
                                ".docx",
                                ".html",
                                ".htm",
                            ],
                        )
                with gr.TabItem("Endpoint Config."):
                    with gr.Group(elem_id="tab-content"):
                        llm_model_name = gr.Textbox(label="LLM Model Name")
                        llm_url = gr.Textbox(label="LLM URL")
                        with gr.Row():
                            # Masked by default; the checkbox toggles visibility.
                            llm_api_key = gr.Textbox(
                                placeholder="LLM API Key",
                                type="password",
                                scale=4,
                                show_label=False,
                            )
                            llm_api_key_visible = gr.Checkbox(
                                label="Show LLM API Key",
                                value=False,
                                scale=1,
                            )
                        embeddings_model_name = gr.Textbox(
                            label="Embeddings Model Name"
                        )
                        embeddings_url = gr.Textbox(label="Embeddings URL")
                        with gr.Row():
                            embeddings_api_key = gr.Textbox(
                                label="Embeddings API Key",
                                type="password",
                                scale=4,
                                show_label=False,
                            )
                            embeddings_api_key_visible = gr.Checkbox(
                                label="Show Embeddings API Key",
                                value=False,
                                scale=1,
                            )
                        save_btn = gr.Button("Save")
                        # Hidden until a save attempt fails.
                        config_error = gr.Markdown(value="", visible=False)
    # Prefill on load
    demo.load(
        get_endpoint_config,
        inputs=None,
        outputs=[
            llm_model_name,
            llm_api_key,
            llm_url,
            embeddings_model_name,
            embeddings_api_key,
            embeddings_url,
        ],
    )
    # Fresh vector store and conversation for every new session.
    demo.load(
        clear_document_list,
        inputs=None,
        outputs=[doc_list],
    )
    demo.load(
        clear_history,
        inputs=None,
        outputs=[uuid_state, chatbot],
    )
    # Saving the config re-reads it, clears the indexed documents, and
    # shows/hides the error banner.
    save_btn.click(
        set_endpoint_config,
        inputs=[
            llm_model_name,
            llm_api_key,
            llm_url,
            embeddings_model_name,
            embeddings_api_key,
            embeddings_url,
        ],
        outputs=[
            llm_model_name,
            llm_api_key,
            llm_url,
            embeddings_model_name,
            embeddings_api_key,
            embeddings_url,
            doc_list,
            config_error,
        ],
    )
    llm_api_key_visible.change(
        toggle_api_key_visibility,
        inputs=[llm_api_key_visible, llm_api_key],
        outputs=llm_api_key,
    )
    embeddings_api_key_visible.change(
        toggle_api_key_visibility,
        inputs=[embeddings_api_key_visible, embeddings_api_key],
        outputs=embeddings_api_key,
    )
    # Clearing the chat also rotates the thread id so memory starts fresh.
    chatbot.clear(clear_history, outputs=[uuid_state, chatbot])
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input, uuid_state], [chatbot, chat_input]
    ).then(lambda: gr.Textbox(interactive=True), None, [chat_input])
    files_new.upload(on_files_uploaded, inputs=[chatbot, files_new], outputs=[chatbot])
    # Removing files (individually or via clear) drops the whole index.
    files_new.delete(on_files_deleted, inputs=None, outputs=None)
    files_new.clear(on_files_deleted, inputs=None, outputs=None)
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # allowed_paths lets Gradio serve the logo file referenced in the header.
    demo.launch(allowed_paths=[logo_path], favicon_path=logo_path)
Dockerfile
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT
FROM python:3.11-slim
# Set the working directory
WORKDIR /app
# Copy the requirements file first so the dependency layer caches independently of code changes
COPY requirements.txt .
# Install the dependencies
RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt
# Copy the application code
COPY src/ ./src/
COPY run_rag.py ./run_rag.py
COPY logo.png ./logo.png
# Expose the port the app runs on
EXPOSE 7860
# Bind Gradio to all interfaces so the app is reachable from outside the container
ENV GRADIO_SERVER_NAME="0.0.0.0"
# Command to run the application
CMD ["python", "run_rag.py"]
src/rag_pipeline.py
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT
import logging
import os
import uuid
from typing import List
from langchain.schema import BaseMessage
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from src.reader import MixedFileTypeLoader
from transformers import AutoConfig
class RagPipeline:
    """End-to-end RAG pipeline: document ingestion, retrieval, and generation.

    Wraps OpenAI-compatible LLM and embeddings endpoints, an in-memory vector
    store, and a LangGraph conversation graph with per-thread memory.
    """

    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        top_k: int = 5,
        use_tools: bool = False,
    ):
        """
        Initializes the pipeline and sets up models and vector store.
        Args:
            chunk_size (int, optional): The size of each chunk for processing. Defaults to 500.
            chunk_overlap (int, optional): The overlap size between consecutive chunks. Defaults to 50.
            top_k (int, optional): The number of top scored documents to retrieve. Defaults to 5.
            use_tools (bool, optional): Flag to indicate whether to use LLM tool calling.
                If True, the LLM will decide whether it needs to use tools to retrieve information or directly
                respond to the query without invoking the document search.
                If False, document search will always be performed and the LLM will always receive the
                corresponding context to respond to the query. Defaults to False.
        """
        self.top_k: int = top_k
        self.chunk_size: int = chunk_size
        self.chunk_overlap: int = chunk_overlap
        self.use_tools: bool = use_tools
        self.intro_prompt: str = (
            "You are FlexBot, an assistant for question-answering tasks. "
        )
        self.rag_prompt: str = (
            "You are FlexBot, an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use six sentences maximum and keep the "
            "answer concise."
        )
        # Declared here for readability; populated by the _set_* helpers below.
        self.llm: ChatOpenAI
        self.embeddings: OpenAIEmbeddings
        self.vector_store: InMemoryVectorStore
        self.prompt: PromptTemplate
        self.graph: StateGraph
        self.llm_model_name: str
        self.llm_api_key: str
        self.llm_url: str
        self.embeddings_model_name: str
        self.embeddings_api_key: str
        self.embeddings_url: str
        self._set_endpoint_config()
        self._set_models()
        self._set_vector_store()
        self._set_graph()

    def add_documents(self, file_paths: List[str]) -> None:
        """Load, chunk, and index documents into the vector store.

        Args:
            file_paths: Local paths (or URLs) of the documents to index.
        """
        # parse documents into chunks
        loader = MixedFileTypeLoader(file_paths)
        docs = loader.load()
        for doc in docs:
            # only keep document filename
            doc.metadata["source"] = os.path.basename(doc.metadata["source"])
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        all_splits = text_splitter.split_documents(docs)
        # index chunked documents
        _ = self.vector_store.add_documents(documents=all_splits)

    def _check_env(self) -> None:
        """Raise ValueError for the first missing required environment variable."""
        required = (
            "LLM_API_KEY",
            "LLM_URL",
            "LLM_MODEL_NAME",
            "EMBEDDINGS_API_KEY",
            "EMBEDDINGS_URL",
            "EMBEDDINGS_MODEL_NAME",
        )
        for var in required:
            if var not in os.environ:
                raise ValueError(f"Please set the {var} environment variable.")

    def get_endpoint_config(self) -> dict:
        """Return the current endpoint configuration as a plain dict."""
        return {
            "llm_model_name": self.llm_model_name,
            "llm_api_key": self.llm_api_key,
            "llm_url": self.llm_url,
            "embeddings_model_name": self.embeddings_model_name,
            "embeddings_api_key": self.embeddings_api_key,
            "embeddings_url": self.embeddings_url,
        }

    def set_endpoint_config(self, config: dict) -> None:
        """Update endpoint settings, then rebuild the models and vector store.

        All keys are validated before any are applied, so an invalid key
        cannot leave the pipeline partially reconfigured.

        Args:
            config: Mapping of endpoint attribute names to new values.

        Raises:
            ValueError: If ``config`` contains an unknown key.
        """
        valid_keys = {
            "llm_model_name",
            "llm_api_key",
            "llm_url",
            "embeddings_model_name",
            "embeddings_api_key",
            "embeddings_url",
        }
        for key in config:
            if key not in valid_keys:
                raise ValueError(f"Invalid config key: {key}")
        for key, value in config.items():
            setattr(self, key, value)
        self._set_models()
        # Embeddings may have changed, so the existing index is invalid.
        self._set_vector_store()

    def _set_endpoint_config(self) -> None:
        """Initialize endpoint settings from environment variables."""
        self._check_env()
        self.llm_model_name = os.getenv("LLM_MODEL_NAME")
        self.llm_api_key = os.getenv("LLM_API_KEY")
        self.llm_url = os.getenv("LLM_URL")
        self.embeddings_model_name = os.getenv("EMBEDDINGS_MODEL_NAME")
        self.embeddings_api_key = os.getenv("EMBEDDINGS_API_KEY")
        self.embeddings_url = os.getenv("EMBEDDINGS_URL")

    def _set_models(self) -> None:
        """Instantiate the chat and embeddings clients from the current config.

        Raises:
            ValueError: If ``chunk_size`` exceeds the embeddings model's
                maximum input length (a bare assert would be stripped
                under ``python -O`` and carried no message).
        """
        config = AutoConfig.from_pretrained(self.embeddings_model_name)
        if self.chunk_size > config.max_position_embeddings:
            raise ValueError(
                f"chunk_size ({self.chunk_size}) exceeds the embeddings "
                f"model's maximum input length "
                f"({config.max_position_embeddings})."
            )
        llm = ChatOpenAI(
            model_name=self.llm_model_name,
            openai_api_key=self.llm_api_key,
            openai_api_base=self.llm_url + "/v1",
        )
        embeddings = OpenAIEmbeddings(
            model=self.embeddings_model_name,
            deployment=self.embeddings_model_name,
            openai_api_key=self.embeddings_api_key,
            openai_api_base=self.embeddings_url + "/v1",
            # The endpoint is not an OpenAI model; skip tiktoken tokenization.
            tiktoken_enabled=False,
        )
        self.llm = llm
        self.embeddings = embeddings

    def _set_vector_store(self) -> None:
        """Create a fresh (empty) in-memory vector store."""
        self.vector_store = InMemoryVectorStore(self.embeddings)

    def _get_generate_prompt(
        self, state: MessagesState, docs_content: str
    ) -> List[BaseMessage]:
        """Build the generation prompt: RAG system message plus the conversation.

        Tool-call AI messages are filtered out so only human/system turns and
        final AI answers are replayed to the model.
        """
        # Format into prompt
        system_message_content = f"{self.rag_prompt}" "\n\n" f"{docs_content}"
        conversation_messages = [
            message
            for message in state["messages"]
            if message.type in ("human", "system")
            or (message.type == "ai" and not message.tool_calls)
        ]
        prompt = [SystemMessage(system_message_content)] + conversation_messages
        return prompt

    def _serialize_docs(self, docs: List[Document]) -> str:
        """Serialize documents into a 'Source:/Content:' text block.

        Works on a copy of each document's metadata so the documents held by
        the vector store are never mutated (the previous in-place ``pop``
        stripped the 'source' key from stored documents).
        """
        serialized_docs = []
        for doc in docs:
            metadata = dict(doc.metadata)
            if "source" in metadata:
                # remove "source" key from metadata and display it first
                src_string = metadata.pop("source")
                if metadata:
                    src_string += f" {metadata}"
            else:
                src_string = str(metadata)
            serialized_docs.append(
                f"Source: {src_string}\n" f"Content: {doc.page_content}"
            )
        return "\n\n".join(serialized_docs)

    def _set_graph(self) -> None:
        """Build the conversation graph for the configured retrieval mode."""
        if self.use_tools:
            self._set_graph_with_tool_calling()
        else:
            self._set_graph_without_tool_calling()

    def _set_graph_with_tool_calling(self) -> None:
        """Build a graph where the LLM decides whether to call the retriever."""

        @tool(response_format="content_and_artifact")
        def retrieve(query: str):
            """Use this tool to retrieve information from documents stored in the knowledge base."""
            retrieved_docs = self.vector_store.similarity_search(query, k=self.top_k)
            serialized = self._serialize_docs(retrieved_docs)
            # return content, artifact
            return serialized, retrieved_docs

        def query_or_respond(state: MessagesState):
            """Generate tool call for retrieval or respond."""
            llm_with_tools = self.llm.bind_tools([retrieve])
            prompt = [SystemMessage(self.intro_prompt)] + state["messages"]
            response = llm_with_tools.invoke(prompt)
            # MessagesState appends messages to state instead of overwriting
            return {"messages": [response]}

        def generate(state: MessagesState):
            """Generate answer."""
            # Get context from generated ToolMessages
            recent_tool_messages = []
            for message in reversed(state["messages"]):
                if message.type == "tool":
                    recent_tool_messages.append(message)
                else:
                    break
            tool_messages = recent_tool_messages[::-1]
            docs_content = "\n\n".join(doc.content for doc in tool_messages)
            prompt = self._get_generate_prompt(state, docs_content)
            # Run
            response = self.llm.invoke(prompt)
            return {"messages": [response]}

        graph_builder = StateGraph(MessagesState)
        memory = MemorySaver()
        tools = ToolNode([retrieve])
        # Node 1: Generate an AIMessage that may include a tool-call to be sent.
        graph_builder.add_node(query_or_respond)
        # Node 2: Execute the retrieval tool.
        graph_builder.add_node(tools)
        # Node 3: Generate a response using the retrieved content.
        graph_builder.add_node(generate)
        graph_builder.set_entry_point("query_or_respond")
        graph_builder.add_conditional_edges(
            "query_or_respond",
            tools_condition,
            {END: END, "tools": "tools"},
        )
        graph_builder.add_edge("tools", "generate")
        graph_builder.add_edge("generate", END)
        self.graph = graph_builder.compile(checkpointer=memory)

    def _set_graph_without_tool_calling(self) -> None:
        """Build a linear graph that always retrieves before generating."""

        class ContextState(MessagesState):
            context: List[Document]

        def retrieve(state: ContextState):
            # Get last HumanMessage, which is the question
            for message in reversed(state["messages"]):
                if message.type == "human":
                    question = message.content
                    break
            else:
                raise ValueError("No human message found in the state.")
            # Honor top_k here as well (the tool-calling path already did;
            # previously this call fell back to the library default).
            retrieved_docs = self.vector_store.similarity_search(
                question, k=self.top_k
            )
            return {"context": retrieved_docs}

        def generate(state: ContextState):
            """Generate answer."""
            docs_content = self._serialize_docs(state["context"])
            prompt = self._get_generate_prompt(state, docs_content)
            # Run
            response = self.llm.invoke(prompt)
            return {"messages": [response]}

        graph_builder = StateGraph(ContextState).add_sequence([retrieve, generate])
        graph_builder.add_edge(START, "retrieve")
        memory = MemorySaver()
        self.graph = graph_builder.compile(checkpointer=memory)

    def query(
        self, input_message: str, thread_id: str = None
    ) -> dict[str, List[BaseMessage]]:
        """Run the RAG graph on a user message within a conversation thread.

        Args:
            input_message: The user's question.
            thread_id: Conversation id for the checkpointer; a fresh one is
                generated (with a warning) when omitted.

        Returns:
            The final graph state, including the full message history.
        """
        if thread_id is None:
            logging.warning("thread_id is NONE")
            thread_id = str(uuid.uuid4())
        config = {"configurable": {"thread_id": thread_id}}
        res = self.graph.invoke(
            {"messages": [{"role": "user", "content": input_message}]},
            config=config,
        )
        return res

    def clear_vector_store(self) -> None:
        """Clear the vector store."""
        logging.info("Clear the vector store")
        self._set_vector_store()
src/reader.py
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT
import os
from langchain.document_loaders.base import BaseLoader
from langchain_community.document_loaders import (
BSHTMLLoader,
Docx2txtLoader,
PyPDFLoader,
WebBaseLoader,
)
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders.text import TextLoader
class MixedFileTypeLoader(BaseLoader):
    """Load documents from local files of mixed types or from web URLs."""

    def __init__(self, file_path):
        # A single path/URL, or a list of them.
        self.file_path = file_path

    def load(self):
        """Load every input and return a flat list of documents."""
        if isinstance(self.file_path, list):
            documents = []
            for file in self.file_path:
                docs = self._load_file(file)
                if isinstance(docs, list):
                    documents.extend(docs)
                else:
                    documents.append(docs)
            return documents
        return self._load_file(self.file_path)

    def _load_file(self, file):
        """Dispatch to the right loader based on URL scheme or file extension.

        Raises:
            FileNotFoundError: If a local path does not exist.
            ValueError: If the file extension is unsupported.
        """
        is_url = file.startswith("http://") or file.startswith("https://")
        if is_url:
            # Bug fix: URLs are fetched remotely, so they must be dispatched
            # BEFORE the local-existence check; previously os.path.exists()
            # raised FileNotFoundError for every URL, making _load_web
            # unreachable.
            return self._load_web(file)
        if not os.path.exists(file):
            raise FileNotFoundError(f"File not found: {file}")
        file_extension = os.path.splitext(file)[1].lower()
        if file_extension == ".pdf":
            return self._load_pdf(file)
        elif file_extension in [".doc", ".docx"]:
            return self._load_word(file)
        elif file_extension == ".txt":
            return self._load_txt(file)
        elif file_extension in [".html", ".htm"]:
            return self._load_html(file)
        elif file_extension == ".csv":
            return self._load_csv(file)
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

    def _load_pdf(self, file_path):
        loader = PyPDFLoader(file_path)
        return loader.load()

    def _load_word(self, file_path):
        loader = Docx2txtLoader(file_path)
        return loader.load()

    def _load_txt(self, file_path):
        loader = TextLoader(file_path)
        return loader.load()

    def _load_html(self, file_path):
        loader = BSHTMLLoader(file_path)
        return loader.load()

    def _load_csv(self, file_path):
        loader = CSVLoader(file_path)
        return loader.load()

    def _load_web(self, file_path):
        loader = WebBaseLoader(file_path)
        return loader.load()