import os
from fastapi import FastAPI, HTTPException
from langchain.prompts import PromptTemplate
from pydantic import BaseModel
from dotenv import load_dotenv
from embeddings.embeddings import generate_embeddings
from elastic.retrieval import search_certification_chunks
from prompting.rewrite_question import classify_certification, initialize_llms, process_query

load_dotenv()
app = FastAPI(
    title="Hydrogen Certification RAG System",
    description="API for querying hydrogen certification documents using RAG",
    version="0.1.0"
)
# Initialize LLMs
llms = initialize_llms()
# Request models
class QueryRequest(BaseModel):
    query: str
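# Illustrative request body for the query endpoint (hypothetical example):
#   {"query": "What are the audit requirements for certification X?"}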
# Reuse the already-initialized rewrite LLM for answer generation
llm = llms["rewrite_llm"]
# Endpoints
@app.post("/query")  # NOTE: route path assumed
async def handle_query(request: QueryRequest):
""" | |
Process a query through the full RAG pipeline: | |
1. Classify certification (if not provided) | |
2. Optimize query based on specificity | |
3. Search relevant chunks | |
""" | |
    try:
        # Step 1: Determine which certification the query is about
        certification = classify_certification(request.query, llms["rewrite_llm"])
        if "no certification mentioned" in certification:
            raise HTTPException(
                status_code=400,
                detail="No certification specified in the query"
            )

        # Step 2: Rewrite the query and embed it
        processed_query = process_query(request.query, llms)
        question_vector = generate_embeddings(processed_query)

        # Step 3: Search both indices with the rewritten query and its embedding
        results = search_certification_chunks(
            index_name="certif_index",
            certification_name=certification,
            text_query=processed_query,
            vector_query=question_vector,
        )
        results_ = search_certification_chunks(
            index_name="certification_index",
            certification_name=certification,
            text_query=processed_query,
            vector_query=question_vector,
        )
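        # search_certification_chunks is assumed to return an iterable of hits,
        # each a dict with at least a "text" field; that is the only field this
        # endpoint relies on below.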
        # Collect the chunk texts and join them into a single context string per index
        results_list = [result["text"] for result in results]
        results_list_ = [result["text"] for result in results_]
        results_merged = ". ".join(results_list)
        results_merged_ = ". ".join(results_list_)
        # Build the answer-generation prompt and run it once per index
        template = """
        You are an AI assistant tasked with providing answers based on the given context about a specific hydrogen certification.
        Provide a clear, concise response that directly addresses the question without unnecessary information.
        Question: {question}
        Certification: {certification}
        Context: {context}
        Answer:
        """
        prompt = PromptTemplate(
            input_variables=["question", "certification", "context"],
            template=template
        )
        chain = prompt | llm
        answer = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged}).content
        answer_ = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged_}).content
        return {
            "certification": certification,
            "certif_index": answer,
            "certification_index": answer_,
            "context_certif": results_list,
            "context_certifications": results_list_
        }
    except HTTPException:
        # Propagate HTTP errors (e.g. the 400 above) without wrapping them in a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/certifications")  # NOTE: route path assumed
async def list_certifications():
    """List all available certifications."""
    try:
        certs_dir = "docs/processed"
        return [f for f in os.listdir(certs_dir) if os.path.isdir(os.path.join(certs_dir, f))]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
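# The listing above assumes one sub-directory per certification under
# docs/processed/ (illustrative layout): docs/processed/<certification_name>/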
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
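# Illustrative usage once the server is running locally (route paths assumed above):
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What are the audit requirements for certification X?"}'
#   curl http://localhost:8000/certifications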