rohitkshirsagar19 committed
Commit 74bf5bb · verified · 1 Parent(s): ee9f80e

Update main.py

Files changed (1):
  1. main.py (+27 −50)
main.py CHANGED
```diff
@@ -1,7 +1,7 @@
 import uvicorn
 from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
+from fastapi.middleware.cors import CORSMiddleware
 from sentence_transformers import SentenceTransformer
 from pinecone import Pinecone, ServerlessSpec
 import uuid
@@ -11,31 +11,25 @@ from contextlib import asynccontextmanager
 # --- Environment Setup ---
 PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
 PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "memoria-index")
-# Define a writable cache directory inside our container
-CACHE_DIR = "/app/model_cache"
+CACHE_DIR = "/app/model_cache"  # For Hugging Face caching
 
-# --- Global objects ---
+# --- Global Objects ---
 model = None
 pc = None
 index = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """
-    Handles startup and shutdown events for the FastAPI app.
-    Loads the model and connects to Pinecone on startup.
-    """
     global model, pc, index
     print("Application startup...")
 
     if not PINECONE_API_KEY:
         raise ValueError("PINECONE_API_KEY environment variable not set.")
 
-    # 1. Load the AI Model
-    print(f"Loading model and setting cache to: {CACHE_DIR}")
-    # THE FINAL FIX: Explicitly tell the library where to save the model.
+    # 1. Load the official, industry-standard lightweight model.
+    print("Loading sentence-transformers/all-MiniLM-L6-v2 model...")
     model = SentenceTransformer(
-        'sentence-transformers/paraphrase-albert-small-v2',
+        'sentence-transformers/all-MiniLM-L6-v2',
        cache_folder=CACHE_DIR
     )
     print("Model loaded.")
```
```diff
@@ -44,12 +38,15 @@
     print("Connecting to Pinecone...")
     pc = Pinecone(api_key=PINECONE_API_KEY)
 
-    # 3. Get or create the Pinecone index
+    # 3. Get or create the Pinecone index with the correct dimension.
+    model_dimension = model.get_sentence_embedding_dimension()
+    print(f"Model dimension is: {model_dimension}")
+
     if PINECONE_INDEX_NAME not in pc.list_indexes().names():
-        print(f"Creating new Pinecone index: {PINECONE_INDEX_NAME}")
+        print(f"Creating new Pinecone index: {PINECONE_INDEX_NAME} with dimension {model_dimension}")
         pc.create_index(
             name=PINECONE_INDEX_NAME,
-            dimension=model.get_sentence_embedding_dimension(),
+            dimension=model_dimension,
             metric="cosine",
             spec=ServerlessSpec(cloud="aws", region="us-east-1")
         )
```
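One caveat worth flagging: create_index only runs when the index is absent. If memoria-index was already created for the old model's dimension, upserts from the new model will fail with a dimension mismatch. A hedged sketch of a guard that recreates a stale index, reusing the pc, PINECONE_INDEX_NAME, and model_dimension names from the code above (attribute names assume the v3 Pinecone client used here):

```python
# Hedged sketch: drop and recreate the index if its dimension is stale.
# WARNING: delete_index is destructive and discards all stored vectors.
if PINECONE_INDEX_NAME in pc.list_indexes().names():
    existing = pc.describe_index(PINECONE_INDEX_NAME)
    if existing.dimension != model_dimension:
        pc.delete_index(PINECONE_INDEX_NAME)
```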
```diff
@@ -58,30 +55,18 @@ async def lifespan(app: FastAPI):
     yield
     print("Application shutdown.")
 
-# ... (The rest of the file remains exactly the same) ...
-
-# --- Pydantic Models ---
+# --- Pydantic Models & FastAPI App ---
 class Memory(BaseModel):
     content: str
-
 class SearchQuery(BaseModel):
     query: str
 
-# --- FastAPI App ---
 app = FastAPI(
     title="Memoria API",
-    description="API for storing and retrieving memories.",
-    version="1.0.1", # Final deployed version
+    version="1.1.0",
     lifespan=lifespan
 )
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
 
 # --- API Endpoints ---
 @app.get("/")
```
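The collapsed add_middleware call keeps the fully open CORS policy from before. That is convenient while developing, but a deployed API would normally pin allow_origins to the real frontend. A sketch with a placeholder origin (memoria.example.com is hypothetical, not a real deployment URL):

```python
# Hedged sketch: restrict CORS before production; the origin is a placeholder.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://memoria.example.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
```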
```diff
@@ -89,28 +74,20 @@ def read_root():
     return {"status": "ok", "message": "Welcome to the Memoria API!"}
 
 @app.post("/save_memory")
-def save_memory(memory: Memory):
-    try:
-        embedding = model.encode(memory.content).tolist()
-        memory_id = str(uuid.uuid4())
-        index.upsert(vectors=[{"id": memory_id, "values": embedding, "metadata": {"text": memory.content}}])
-        print(f"Successfully saved memory with ID: {memory_id}")
-        return {"status": "success", "id": memory_id}
-    except Exception as e:
-        print(f"An error occurred during save: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+def save_memory_endpoint(memory: Memory):
+    embedding = model.encode(memory.content).tolist()
+    memory_id = str(uuid.uuid4())
+    index.upsert(vectors=[{"id": memory_id, "values": embedding, "metadata": {"text": memory.content}}])
+    print(f"Saved memory: {memory_id}")
+    return {"status": "success", "id": memory_id}
 
 @app.post("/search_memory")
-def search_memory(search: SearchQuery):
-    try:
-        query_embedding = model.encode(search.query).tolist()
-        results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
-        retrieved_documents = [match['metadata']['text'] for match in results['matches']]
-        print(f"Found {len(retrieved_documents)} results for query: '{search.query}'")
-        return {"status": "success", "results": retrieved_documents}
-    except Exception as e:
-        print(f"An error occurred during search: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+def search_memory_endpoint(search: SearchQuery):
+    query_embedding = model.encode(search.query).tolist()
+    results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
+    retrieved_documents = [match['metadata']['text'] for match in results['matches']]
+    print(f"Found {len(retrieved_documents)} results for query: '{search.query}'")
+    return {"status": "success", "results": retrieved_documents}
 
 if __name__ == "__main__":
     uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)
```
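Note that this revision drops the try/except wrappers, so failures in encode, upsert, or query now surface as FastAPI's default 500 response, and the HTTPException import is left unused. A quick way to exercise both endpoints against a local run (a usage sketch; the payload strings are made up):

```python
# Usage sketch: assumes the server from main.py is running on 127.0.0.1:8000.
import requests

BASE = "http://127.0.0.1:8000"

saved = requests.post(f"{BASE}/save_memory", json={"content": "Milk expires on Friday."})
print(saved.json())  # {"status": "success", "id": "<some uuid>"}

found = requests.post(f"{BASE}/search_memory", json={"query": "When does the milk expire?"})
print(found.json())  # {"status": "success", "results": [...]}
```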
 