# SEALION-v3.5-8B-R-RAG / src/simple_rag.py
# Darayut's picture
# Update src/simple_rag.py
# 8a7f3b7 verified
# Modified RAG Pipeline for General Document Q&A (Khmer & English)
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFDirectoryLoader
from openai import OpenAI
logging.basicConfig(level=logging.INFO)

# Detect once, at import time, whether torch can see a CUDA device and report
# it through the logging setup configured above (the rest of this module logs
# via `logging` as well).
use_gpu = torch.cuda.is_available()
if use_gpu:
    logging.info("CUDA device in use: %s", torch.cuda.get_device_name(0))
else:
    logging.info("Running on CPU. No GPU detected.")
# The SEA-LION API key is read from the environment (set as a secret on the
# Hugging Face Space); it is None when the variable is not defined.
SEALION_API_KEY = os.getenv("SEALION_API_KEY")

# OpenAI SDK client pointed at the SEA-LION endpoint instead of the default one.
client = OpenAI(base_url="https://api.sea-lion.ai/v1", api_key=SEALION_API_KEY)
# Everything the app reads or writes lives under $HOME (falling back to /app
# when HOME is not set), inside a "src" sub-directory.
WRITABLE_DIR = os.getenv("HOME", "/app")
_SRC_DIR = os.path.join(WRITABLE_DIR, "src")
DATA_PATH = os.path.join(_SRC_DIR, "data")  # directory scanned for PDFs
CHROMA_PATH = os.path.join(_SRC_DIR, "chroma")  # Chroma persistence directory
# Single embedding model shared by indexing (save_to_chroma) and querying
# (ask_question), so stored vectors and query vectors come from the same model.
# NOTE(review): the E5 model card recommends "query: " / "passage: " input
# prefixes; the code visible here adds none — confirm that this is intended.
embedding_model = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-base")
# Prompt for the chat model, shared by Khmer and English questions.
# {context} and {question} are filled in with str.format() by ask_question.
# The template is assembled line by line; the empty first and last entries
# give it a leading and a trailing newline.
PROMPT_TEMPLATE = "\n".join(
    (
        "",
        "Respond ONLY in the same language as the question.",
        "If the question is in English, answer in English.",
        "If the question is in Khmer, answer in Khmer.",
        "You are a helpful assistant.",
        "Use only the provided context below to answer the question.",
        "Do not mention the context or that you used it.",
        "Context:",
        "{context}",
        "Question:",
        "{question}",
        "Answer:",
        "",
    )
)
def load_documents():
    """Load the PDFs found under ``DATA_PATH``.

    Returns:
        The result of ``PyPDFDirectoryLoader(DATA_PATH).load()``.
    """
    return PyPDFDirectoryLoader(DATA_PATH).load()
def split_text(
    documents: list[Document], chunk_size: int = 512, chunk_overlap: int = 100
):
    """Split documents into overlapping character-based chunks.

    Args:
        documents: Documents to split.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
            Defaults to the previously hard-coded 512.
        chunk_overlap: Number of characters shared between neighbouring
            chunks. Defaults to the previously hard-coded 100.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter``. Because
        ``add_start_index=True`` is passed, the splitter is asked to record
        each chunk's start offset in its metadata.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        add_start_index=True,
    )
    chunks = splitter.split_documents(documents)
    logging.info(
        "Split %d documents into %d chunks.", len(documents), len(chunks)
    )
    return chunks
def _chunk_id(chunk) -> str:
    """Return a stable ID derived from a chunk's origin metadata and text."""
    import hashlib  # local import keeps this block self-contained

    meta = chunk.metadata or {}
    key = "\x1f".join(
        str(part)
        for part in (
            meta.get("source"),
            meta.get("page"),
            meta.get("start_index"),
            chunk.page_content,
        )
    )
    # errors="replace" so stray surrogates from PDF extraction cannot raise.
    return hashlib.sha256(key.encode("utf-8", errors="replace")).hexdigest()


def save_to_chroma(chunks: list[Document]):
    """Persist chunks to the Chroma store at ``CHROMA_PATH``.

    Every chunk is stored under a deterministic ID (hash of source, page,
    start index and text). Running ingestion again over the same PDFs
    therefore targets the same IDs instead of piling up copies under fresh
    random ones. This relies on the LangChain Chroma wrapper writing explicit
    IDs with upsert semantics; verify against the installed version.

    Args:
        chunks: Chunks to store. An empty list is a no-op.
    """
    if not chunks:
        logging.warning("No chunks to save; Chroma DB left unchanged.")
        return

    # Keying by ID collapses exact duplicates inside this batch, which would
    # otherwise be submitted twice under the same ID.
    unique = {_chunk_id(chunk): chunk for chunk in chunks}
    ids = list(unique)
    docs = list(unique.values())

    if os.path.exists(CHROMA_PATH):
        db = Chroma(
            persist_directory=CHROMA_PATH, embedding_function=embedding_model
        )
        db.add_documents(docs, ids=ids)
        logging.info("Added documents to existing Chroma DB.")
    else:
        db = Chroma.from_documents(
            docs, embedding_model, ids=ids, persist_directory=CHROMA_PATH
        )
        logging.info("Created new Chroma DB.")

    db.persist()
    logging.info("Saved %d chunks to Chroma.", len(docs))
def generate_data_store():
    """Run the ingestion pipeline: load the PDFs, chunk them, store the chunks."""
    save_to_chroma(split_text(load_documents()))
def ask_question(query_text: str, k: int = 3):
    """Answer a question from the indexed documents.

    Retrieves the ``k`` chunks most similar to the question from the Chroma
    store, inserts their text into ``PROMPT_TEMPLATE`` and sends the prompt
    to the SEA-LION chat completions API.

    Args:
        query_text: The user's question.
        k: Number of chunks to retrieve as context.

    Returns:
        A tuple ``(answer, context_chunks)``. ``answer`` is the model's reply
        with surrounding whitespace stripped, or a fixed apology message when
        the API call fails or the completion has no text. ``context_chunks``
        is a list of dicts with the keys ``"filename"``, ``"page"`` and
        ``"text"``, one per retrieved chunk.
    """
    fallback = "Sorry, something went wrong when contacting the language model."

    logging.info("Processing user question...")
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_model)
    results = db.similarity_search(query_text, k=k)

    context_chunks = []
    for doc in results:
        meta = doc.metadata or {}
        context_chunks.append(
            {
                "filename": os.path.basename(meta.get("source", "unknown.pdf")),
                "page": meta.get("page", 1),
                "text": doc.page_content.strip(),
            }
        )

    context_text = "\n\n".join(chunk["text"] for chunk in context_chunks)
    prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text)
    messages = [{"role": "user", "content": prompt}]

    try:
        logging.info("Sending prompt to SEA-LION API...")
        completion = client.chat.completions.create(
            model="aisingapore/Llama-SEA-LION-v3.5-8B-R",
            messages=messages,
            extra_body={
                "chat_template_kwargs": {"thinking_mode": "off"},
                "cache": {"no-cache": True},
            },
            max_tokens=512,
        )
        content = completion.choices[0].message.content
    except Exception as e:
        # Boundary handler: callers always get a reply string. The traceback
        # is kept in the log so the failure can be diagnosed.
        logging.exception("Error calling SEA-LION API: %s", e)
        return fallback, context_chunks

    if content is None:
        # A completion without text is not a transport error; log it as what
        # it is instead of letting None.strip() raise inside the handler above.
        logging.error("SEA-LION API returned a completion without text content.")
        return fallback, context_chunks

    return content.strip(), context_chunks