# memoria-api / main.py
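"""Memoria API.

A small FastAPI service that stores free-text "memories" as vector embeddings
and retrieves them by semantic similarity. Text is embedded with
sentence-transformers/all-MiniLM-L6-v2 and stored in a Pinecone serverless
index; /search_memory returns the stored texts closest to a query by cosine
similarity.
"""
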
import os
import uuid
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer

# --- Environment Setup ---
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "memoria-index")
CACHE_DIR = "/app/model_cache"  # Hugging Face model cache inside the container

# --- Global Objects (populated once during application startup) ---
model = None
pc = None
index = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the embedding model and set up Pinecone once, at application startup."""
    global model, pc, index
    print("Application startup...")
    if not PINECONE_API_KEY:
        raise ValueError("PINECONE_API_KEY environment variable not set.")

    # 1. Load the lightweight all-MiniLM-L6-v2 embedding model.
    print("Loading sentence-transformers/all-MiniLM-L6-v2 model...")
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR,
    )
    print("Model loaded.")

    # 2. Connect to Pinecone.
    print("Connecting to Pinecone...")
    pc = Pinecone(api_key=PINECONE_API_KEY)

    # 3. Get or create the Pinecone index with the model's embedding
    #    dimension (384 for all-MiniLM-L6-v2).
    model_dimension = model.get_sentence_embedding_dimension()
    print(f"Model dimension is: {model_dimension}")
    if PINECONE_INDEX_NAME not in pc.list_indexes().names():
        print(f"Creating new Pinecone index: {PINECONE_INDEX_NAME} with dimension {model_dimension}")
        pc.create_index(
            name=PINECONE_INDEX_NAME,
            dimension=model_dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
    index = pc.Index(PINECONE_INDEX_NAME)
    print("Pinecone setup complete.")

    yield
    print("Application shutdown.")

# --- Pydantic Models & FastAPI App ---
class Memory(BaseModel):
    content: str

class SearchQuery(BaseModel):
    query: str

app = FastAPI(
    title="Memoria API",
    version="1.1.0",
    lifespan=lifespan,
)

# Wide-open CORS so browser-based clients anywhere can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- API Endpoints ---
@app.get("/")
def read_root():
    """Health-check endpoint."""
    return {"status": "ok", "message": "Welcome to the Memoria API!"}

@app.post("/save_memory")
def save_memory_endpoint(memory: Memory):
    """Embed the memory text and store it in Pinecone with the text as metadata."""
    if model is None or index is None:
        raise HTTPException(status_code=503, detail="Service is still starting up.")
    embedding = model.encode(memory.content).tolist()
    memory_id = str(uuid.uuid4())
    index.upsert(vectors=[{
        "id": memory_id,
        "values": embedding,
        "metadata": {"text": memory.content},
    }])
    print(f"Saved memory: {memory_id}")
    return {"status": "success", "id": memory_id}
@app.post("/search_memory")
def search_memory_endpoint(search: SearchQuery):
    """Embed the query and return the texts of the five most similar memories."""
    if model is None or index is None:
        raise HTTPException(status_code=503, detail="Service is still starting up.")
    query_embedding = model.encode(search.query).tolist()
    results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
    retrieved_documents = [match["metadata"]["text"] for match in results["matches"]]
    print(f"Found {len(retrieved_documents)} results for query: '{search.query}'")
    return {"status": "success", "results": retrieved_documents}
if __name__ == "__main__":
    # Local development entry point; in the container the server is typically
    # started by the Dockerfile instead.
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)
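
# To run locally (illustrative; CACHE_DIR above assumes the container's /app):
#   export PINECONE_API_KEY=<your-key>          # required
#   export PINECONE_INDEX_NAME=memoria-index    # optional; this is the default
#   python main.py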