# Spaces: Sleeping  (page-status banner captured with this file; not part of the program)
import os
import uuid
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pinecone import Pinecone, ServerlessSpec
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# --- Environment Setup ---
# Pinecone settings are read from the environment. The API key has no default;
# the index name falls back to "memoria-index".
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "memoria-index")
CACHE_DIR = "/app/model_cache"  # For Hugging Face caching

# --- Global Objects ---
# Placeholders assigned by the lifespan handler at startup; None until then.
model = None
pc = None
index = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up the embedding model and Pinecone index for the app's lifetime.

    The code before ``yield`` runs at startup and assigns the module-level
    ``model``, ``pc`` and ``index`` globals; the code after ``yield`` runs at
    shutdown.

    Args:
        app: The FastAPI application this handler is attached to (unused here).

    Raises:
        ValueError: If ``PINECONE_API_KEY`` is unset or empty.
    """
    global model, pc, index
    print("Application startup...")
    if not PINECONE_API_KEY:
        raise ValueError("PINECONE_API_KEY environment variable not set.")

    # 1. Load the official, industry-standard lightweight model.
    print("Loading sentence-transformers/all-MiniLM-L6-v2 model...")
    model = SentenceTransformer(
        'sentence-transformers/all-MiniLM-L6-v2',
        cache_folder=CACHE_DIR,
    )
    print("Model loaded.")

    # 2. Connect to Pinecone
    print("Connecting to Pinecone...")
    pc = Pinecone(api_key=PINECONE_API_KEY)

    # 3. Get or create the Pinecone index with the correct dimension.
    # The dimension is taken from the loaded model so the index always
    # matches the vectors the endpoints will produce.
    model_dimension = model.get_sentence_embedding_dimension()
    print(f"Model dimension is: {model_dimension}")
    if PINECONE_INDEX_NAME not in pc.list_indexes().names():
        print(f"Creating new Pinecone index: {PINECONE_INDEX_NAME} with dimension {model_dimension}")
        pc.create_index(
            name=PINECONE_INDEX_NAME,
            dimension=model_dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
    index = pc.Index(PINECONE_INDEX_NAME)
    print("Pinecone setup complete.")

    yield

    print("Application shutdown.")
# --- Pydantic Models & FastAPI App ---
class Memory(BaseModel):
    """Request body for saving a memory."""

    # Text to embed; also stored verbatim as the vector's "text" metadata.
    content: str
class SearchQuery(BaseModel):
    """Request body for searching stored memories."""

    # Text to embed and use as the query vector.
    query: str
# The application object; `lifespan` performs model/index setup at startup.
app = FastAPI(
    title="Memoria API",
    version="1.1.0",
    lifespan=lifespan,
)

# NOTE(review): every origin is allowed together with credentials. Confirm
# this is intended before the API starts relying on cookies or other
# browser-sent credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- API Endpoints ---
# NOTE(review): no route decorator is present in this copy of the file; "/" is
# inferred from the handler name -- confirm against the deployed API.
@app.get("/")
def read_root():
    """Return a static payload indicating that the service is reachable."""
    return {"status": "ok", "message": "Welcome to the Memoria API!"}
# NOTE(review): no route decorator is present in this copy of the file; the
# path below is inferred from the handler name -- confirm against the clients.
@app.post("/save_memory")
def save_memory_endpoint(memory: Memory):
    """Embed a memory's text and upsert it into the Pinecone index.

    Args:
        memory: Request body carrying the text to store.

    Returns:
        A dict with ``status`` set to ``"success"`` and the generated ``id``.

    Raises:
        HTTPException: 503 if the model or index has not been initialised.
    """
    # The globals are None until the lifespan handler has run; answer with a
    # clear 503 instead of failing on an attribute lookup against None.
    if model is None or index is None:
        raise HTTPException(status_code=503, detail="Service is not initialised.")

    embedding = model.encode(memory.content).tolist()
    memory_id = str(uuid.uuid4())
    # The original text is kept as metadata so searches can return it.
    index.upsert(
        vectors=[
            {
                "id": memory_id,
                "values": embedding,
                "metadata": {"text": memory.content},
            }
        ]
    )
    print(f"Saved memory: {memory_id}")
    return {"status": "success", "id": memory_id}
# NOTE(review): no route decorator is present in this copy of the file; the
# path below is inferred from the handler name -- confirm against the clients.
@app.post("/search_memory")
def search_memory_endpoint(search: SearchQuery):
    """Embed the query text and return the text of the closest stored memories.

    Args:
        search: Request body carrying the query text.

    Returns:
        A dict with ``status`` set to ``"success"`` and ``results``, a list of
        the ``text`` metadata of up to five matches.

    Raises:
        HTTPException: 503 if the model or index has not been initialised.
    """
    # The globals are None until the lifespan handler has run; answer with a
    # clear 503 instead of failing on an attribute lookup against None.
    if model is None or index is None:
        raise HTTPException(status_code=503, detail="Service is not initialised.")

    query_embedding = model.encode(search.query).tolist()
    results = index.query(vector=query_embedding, top_k=5, include_metadata=True)

    retrieved_documents = []
    for match in results['matches']:
        try:
            retrieved_documents.append(match['metadata']['text'])
        except (KeyError, TypeError, AttributeError):
            # A match without "text" metadata (presumably a vector written by
            # something other than the save endpoint) is skipped rather than
            # failing the whole search.
            continue

    print(f"Found {len(retrieved_documents)} results for query: '{search.query}'")
    return {"status": "success", "results": retrieved_documents}
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` from the `main` module on
    # 127.0.0.1:8000 with auto-reload enabled.
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)