Commit · 44a2e1d
Parent(s): none (initial commit)

Initial commit - Pregnancy RAG Chatbot
Files changed:
- .gitattributes +35 -0
- README.md +14 -0
- app.py +477 -0
- rag_functions.py +246 -0
- requirements.txt +0 -0 (binary)
- utils.py +184 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: Pregnancy RAG Chatbot
+emoji: 🤱
+colorFrom: indigo
+colorTo: pink
+sdk: gradio
+sdk_version: 5.35.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+short_description: Pregnancy Risk Assessment AI Chatbot
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,477 @@
+import gradio as gr
+import os
+import sys
+from datetime import datetime
+import traceback
+
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+
+from rag_functions import get_direct_answer, get_answer_with_query_engine
+from utils import get_index
+print("✅ Successfully imported RAG functions")
+
+class PregnancyRiskAgent:
+    def __init__(self):
+        self.conversation_history = []
+        self.current_symptoms = {}
+        self.risk_assessment_done = False
+        self.user_context = {}
+        self.last_user_query = ""
+
+        self.symptom_questions = [
+            "Are you currently experiencing any unusual bleeding or discharge?",
+            "How would you describe your baby's movements today compared to yesterday?",
+            "Have you had any headaches that won't go away or that affect your vision?",
+            "Do you feel any pressure or pain in your pelvis or lower back?",
+            "Are you experiencing any other symptoms? (If yes, please describe briefly)"
+        ]
+
+        self.current_question_index = 0
+        self.waiting_for_first_response = True
+
+    def add_to_conversation_history(self, role, message):
+        self.conversation_history.append({
+            "role": role,
+            "message": message,
+            "timestamp": datetime.now().isoformat()
+        })
+
+        # Keep only the 20 most recent turns
+        if len(self.conversation_history) > 20:
+            self.conversation_history = self.conversation_history[-20:]
+
+    def get_conversation_context(self):
+        context_parts = []
+
+        recent_history = self.conversation_history[-10:]
+
+        for entry in recent_history:
+            if entry["role"] == "user":
+                context_parts.append(f"User: {entry['message']}")
+            else:
+                context_parts.append(f"Assistant: {entry['message'][:200]}...")
+
+        return "\n".join(context_parts)
+
+    def is_follow_up_question(self, user_input):
+        follow_up_indicators = [
+            "what about", "can you explain", "what does", "why", "how",
+            "tell me more", "what should i", "is it normal", "should i be worried",
+            "what if", "when should", "how long", "what causes", "is this"
+        ]
+
+        user_lower = user_input.lower()
+        return any(indicator in user_lower for indicator in follow_up_indicators)
+
+    def process_user_input(self, user_input, chat_history):
+        try:
+            self.last_user_query = user_input
+            self.add_to_conversation_history("user", user_input)
+
+            if self.waiting_for_first_response:
+                self.current_symptoms["question_0"] = user_input
+                self.waiting_for_first_response = False
+                self.current_question_index = 1
+
+                if self.current_question_index < len(self.symptom_questions):
+                    bot_response = self.symptom_questions[self.current_question_index]
+                else:
+                    bot_response = self.provide_risk_assessment()
+                    self.risk_assessment_done = True
+
+                self.add_to_conversation_history("assistant", bot_response)
+                return bot_response
+
+            elif self.current_question_index < len(self.symptom_questions) and not self.risk_assessment_done:
+                self.current_symptoms[f"question_{self.current_question_index}"] = user_input
+                self.current_question_index += 1
+
+                if self.current_question_index < len(self.symptom_questions):
+                    bot_response = self.symptom_questions[self.current_question_index]
+                else:
+                    bot_response = self.provide_risk_assessment()
+                    self.risk_assessment_done = True
+
+                self.add_to_conversation_history("assistant", bot_response)
+                return bot_response
+
+            else:
+                bot_response = self.handle_follow_up_conversation(user_input)
+                self.add_to_conversation_history("assistant", bot_response)
+                return bot_response
+
+        except Exception as e:
+            print(f"❌ Error in process_user_input: {e}")
+            traceback.print_exc()
+            error_response = "I encountered an error. Please try again or consult your healthcare provider."
+            self.add_to_conversation_history("assistant", error_response)
+            return error_response
+
+    def handle_follow_up_conversation(self, user_input):
+        try:
+            print(f"🔄 Processing follow-up question: {user_input}")
+
+            symptom_summary = self.create_symptom_summary()
+            conversation_context = self.get_conversation_context()
+
+            if any(word in user_input.lower() for word in ["last", "previous", "what did i ask", "my question"]):
+                if self.last_user_query:
+                    return f"Your last question was: \"{self.last_user_query}\"\n\nWould you like me to elaborate on that topic or do you have a different question?"
+                else:
+                    return "I don't have a record of your previous question. Could you please rephrase what you'd like to know?"
+
+            rag_response = get_direct_answer(user_input, symptom_summary, conversation_context=conversation_context, is_risk_assessment=False)
+
+            if "Error" in rag_response or len(rag_response) < 50:
+                print("🔄 Trying alternative method...")
+                rag_response = get_answer_with_query_engine(user_input)
+
+            bot_response = f"""Based on your symptoms and medical literature:
+
+{rag_response}"""
+
+            return bot_response
+
+        except Exception as e:
+            print(f"❌ Error in follow-up conversation: {e}")
+            return "I encountered an error processing your question. Could you please rephrase it or consult your healthcare provider?"
+
+    def create_symptom_summary(self):
+        if not self.current_symptoms:
+            return "No specific symptoms reported yet"
+
+        summary_parts = []
+        for i, (key, response) in enumerate(self.current_symptoms.items()):
+            if i < len(self.symptom_questions):
+                question = self.symptom_questions[i]
+                summary_parts.append(f"{question}: {response}")
+        return "\n".join(summary_parts)
+
+    def parse_risk_level(self, text):
+        import re
+
+        patterns = [
+            r'\*\*Risk Level:\*\*\s*(Low|Medium|High)',
+            r'Risk Level:\s*\*\*(Low|Medium|High)\*\*',
+            r'Risk Level:\s*(Low|Medium|High)',
+            r'\*\*Risk Level:\*\*\s*<(Low|Medium|High)>',
+            r'Risk Level.*?<(Low|Medium|High)>',
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, text, re.IGNORECASE)
+            if match:
+                risk_level = match.group(1).capitalize()
+                print(f"✅ Successfully parsed risk level: {risk_level}")
+                return risk_level
+
+        print(f"❌ Could not parse risk level from: {text[:200]}...")
+        return None
+
+    def provide_risk_assessment(self):
+        all_symptoms = self.create_symptom_summary()
+
+        rag_query = f"Analyze these pregnancy symptoms for risk assessment:\n{all_symptoms}\n\nProvide risk level and medical recommendations."
+        detailed_analysis = get_direct_answer(rag_query, all_symptoms, is_risk_assessment=True)
+
+        print(f"📋 RAG Response: {detailed_analysis[:300]}...")
+
+        llm_risk_level = self.parse_risk_level(detailed_analysis)
+
+        if llm_risk_level:
+            risk_level = llm_risk_level
+
+            if risk_level == "Low":
+                action = "✅ Continue routine prenatal care and self-monitoring"
+            elif risk_level == "Medium":
+                action = "⚠️ Contact your doctor within 24 hours"
+            elif risk_level == "High":
+                action = "🚨 Immediate visit to ER or OB emergency care required"
+        else:
+            print("⚠️ RAG assessment failed, using fallback")
+            risk_level = "Medium"
+            action = "⚠️ Contact your doctor within 24 hours"
+
+        symptom_list = []
+        for i, (key, symptom) in enumerate(self.current_symptoms.items()):
+            question = self.symptom_questions[i] if i < len(self.symptom_questions) else f"Question {i+1}"
+            symptom_list.append(f"• **{question}**: {symptom}")
+
+        assessment = f"""
+## 🏥 **Risk Assessment Complete**
+
+**Risk Level: {risk_level}**
+**Recommended Action: {action}**
+
+### 📋 **Your Reported Symptoms:**
+{chr(10).join(symptom_list)}
+
+### 🔬 **Medical Analysis:**
+{detailed_analysis}
+
+### 💡 **Next Steps:**
+- Follow the recommended action above
+- Keep monitoring your symptoms
+- Contact your healthcare provider if symptoms worsen
+- Feel free to ask me any follow-up questions about pregnancy health
+
+"""
+        return assessment
+
+    def reset_conversation(self):
+        self.conversation_history = []
+        self.current_symptoms = {}
+        self.current_question_index = 0
+        self.risk_assessment_done = False
+        self.waiting_for_first_response = True
+        self.user_context = {}
+        self.last_user_query = ""
+        return get_welcome_message()
+
+def get_welcome_message():
+    return """Hello! I'm here to help assess pregnancy-related symptoms and provide risk insights based on medical literature.
+
+I'll ask you a few important questions about your current symptoms, then provide a risk assessment and recommendations. After that, feel free to ask any follow-up questions!
+
+**To get started, please tell me:**
+Are you currently experiencing any unusual bleeding or discharge?
+
+---
+⚠️ **Important**: This tool is for informational purposes only and should not replace professional medical care. In case of emergency, contact your healthcare provider immediately."""
+
+
+def create_new_agent():
+    return PregnancyRiskAgent()
+
+
+agent = create_new_agent()
+
+def chat_interface_with_reset(user_input, history):
+    global agent
+
+    if user_input.lower() in ["reset", "restart", "new assessment"]:
+        agent = create_new_agent()
+        return get_welcome_message()
+
+    response = agent.process_user_input(user_input, history)
+    return response
+
+def reset_chat():
+    global agent
+    agent = create_new_agent()
+    return [{"role": "assistant", "content": get_welcome_message()}], ""
+
+
+
+custom_css = """
+body, .gradio-container {
+    color: yellow !important;
+}
+
+.header {
+    background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 100%);
+    padding: 2rem;
+    border-radius: 1rem;
+    text-align: center;
+    margin-bottom: 2rem;
+    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
+}
+
+.header h1 {
+    color: black !important;
+    margin-bottom: 0.5rem;
+    font-size: 2.5rem;
+}
+
+.header p {
+    color: black !important;
+    font-size: 1.1rem;
+    margin: 0.5rem 0;
+}
+
+.warning {
+    background-color: #fff4e6;
+    border-left: 6px solid #ff7f00;
+    padding: 15px;
+    border-radius: 5px;
+    margin: 10px 0;
+}
+
+.warning h3 {
+    color: black !important;
+    margin-top: 0;
+}
+
+.warning p {
+    color: black !important;
+    line-height: 1.6;
+}
+
+div[style*="background-color: #e8f5e8"] {
+    color: black !important;
+}
+
+div[style*="background-color: #e8f5e8"] h3 {
+    color: black !important;
+}
+
+div[style*="background-color: #e8f5e8"] li {
+    color: black !important;
+}
+
+.chatbot {
+    color: black !important;
+}
+
+.message {
+    color: black !important;
+}
+
+/* Hide Gradio footer elements */
+.footer {
+    display: none !important;
+}
+
+.gradio-container .footer {
+    display: none !important;
+}
+
+footer {
+    display: none !important;
+}
+
+.api-docs {
+    display: none !important;
+}
+
+.built-with {
+    display: none !important;
+}
+
+.gradio-container > .built-with {
+    display: none !important;
+}
+
+.settings {
+    display: none !important;
+}
+
+div[class*="footer"] {
+    display: none !important;
+}
+
+div[class*="built"] {
+    display: none !important;
+}
+
+/* note: :contains() is not standard CSS; browsers ignore these three rules */
+*:contains("Built with Gradio") {
+    display: none !important;
+}
+
+*:contains("Use via API") {
+    display: none !important;
+}
+
+*:contains("Settings") {
+    display: none !important;
+}
+"""
+
+
+with gr.Blocks(css=custom_css) as demo:
+    gr.HTML("""
+    <div class="header">
+        <h1>🤱 Pregnancy RAG Chatbot</h1>
+        <p><strong style="color: black !important;">Proactive RAG-powered pregnancy risk management</strong></p>
+    </div>
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.HTML("""
+            <div class="warning">
+                <h3>⚠️ Medical Disclaimer</h3>
+                <p>This AI assistant provides information based on medical literature but is NOT a substitute for professional medical advice, diagnosis, or treatment.</p>
+                <p><strong style="color: black !important;">In emergencies, call emergency services immediately.</strong></p>
+            </div>
+            """)
+
+    chatbot = gr.ChatInterface(
+        fn=chat_interface_with_reset,
+        chatbot=gr.Chatbot(
+            value=[{"role": "assistant", "content": get_welcome_message()}],
+            show_label=False,
+            type='messages'
+        ),
+        textbox=gr.Textbox(
+            placeholder="Type your response here...",
+            show_label=False,
+            max_length=1000,
+            submit_btn=True
+        )
+    )
+
+    with gr.Row():
+        reset_btn = gr.Button("🔄 Start New Assessment", variant="secondary")
+
+    reset_btn.click(
+        fn=reset_chat,
+        outputs=[chatbot.chatbot, chatbot.textbox],
+        show_progress=False
+    )
+
+
+def check_groq_connection():
+    try:
+        from utils import llm  # fixed: utils.py lives at the repo root, not in a `backend` package
+        test_response = llm.complete("Hello")
+        print("✅ Groq connection successful")
+        return True
+    except Exception as e:
+        print(f"❌ Groq connection failed: {e}")
+        return False
+
+
+def refresh_page():
+    """Force a complete page refresh"""
+    return None
+
+
+
+if __name__ == "__main__":
+    print("🚀 Starting GraviLog Pregnancy Risk Assessment Agent...")
+    check_groq_connection()
+
+    is_hf_space = os.getenv('SPACE_ID') is not None
+
+    if is_hf_space:
+        print("🌐 Running on Hugging Face Spaces")
+        print("🔄 Each page refresh will start a new conversation")
+        demo.queue().launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            share=False,
+            debug=False
+        )
+    else:
+        print("🏠 Running locally")
+        print("🤖 Using Groq API for LLM processing")
+        print("🔑 Make sure your GROQ_API_KEY is set in environment variables")
+        print("📌 Make sure your Pinecone index is set up and populated")
+
+        demo.queue().launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            share=True,
+            debug=True,
+            show_error=True
+        )
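The question-then-assessment flow in app.py can be exercised without launching the Gradio UI. Below is a minimal sketch, not part of the commit, assuming the Space's dependencies are installed and GROQ_API_KEY plus a populated Pinecone index are available; the scripted answers are invented examples.

# Hypothetical smoke-test driver for the conversation state machine.
from app import PregnancyRiskAgent, get_welcome_message

agent = PregnancyRiskAgent()
print(get_welcome_message())

# Invented answers to the five symptom questions; the fifth reply pushes
# current_question_index past the question list and triggers
# provide_risk_assessment(), which calls the RAG pipeline.
scripted_answers = ["No", "Normal", "No headaches", "Mild back pain", "No other symptoms"]
for answer in scripted_answers:
    print(agent.process_user_input(answer, chat_history=[]))

# After the assessment, input is routed to handle_follow_up_conversation().
print(agent.process_user_input("Is mild back pain normal in pregnancy?", chat_history=[]))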
rag_functions.py
ADDED
@@ -0,0 +1,246 @@
+import os
+import Stemmer
+import requests
+from utils import get_and_chunk_documents, llm, embed_model, get_index
+from llama_index.retrievers.bm25 import BM25Retriever
+from llama_index.core.postprocessor import SentenceTransformerRerank
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.response_synthesizers import get_response_synthesizer
+from llama_index.core.settings import Settings
+from llama_index.core import VectorStoreIndex
+from llama_index.core.llms import ChatMessage
+from llama_index.core.retrievers import QueryFusionRetriever
+import json
+
+
+Settings.llm = llm
+Settings.embed_model = embed_model
+
+
+index = get_index()
+hybrid_retriever = None
+vector_retriever = None
+bm25_retriever = None
+
+if index:
+    try:
+        vector_retriever = index.as_retriever(similarity_top_k=15)
+        print("✅ Vector retriever initialized successfully")
+
+        all_nodes = index.docstore.docs
+        if len(all_nodes) == 0:
+            print("⚠️ Warning: No documents found in index, skipping BM25 retriever")
+            hybrid_retriever = vector_retriever
+        else:
+            has_text_content = False
+            for node_id, node in all_nodes.items():
+                if hasattr(node, 'text') and node.text and node.text.strip():
+                    has_text_content = True
+                    break
+
+            if not has_text_content:
+                print("⚠️ Warning: No text content found in documents, skipping BM25 retriever")
+                hybrid_retriever = vector_retriever
+            else:
+                try:
+                    print("🔄 Creating BM25 retriever...")
+                    bm25_retriever = BM25Retriever.from_defaults(
+                        docstore=index.docstore,
+                        similarity_top_k=15,
+                        verbose=False
+                    )
+                    print("✅ BM25 retriever initialized successfully")
+
+                    hybrid_retriever = QueryFusionRetriever(
+                        retrievers=[vector_retriever, bm25_retriever],
+                        similarity_top_k=20,
+                        num_queries=1,
+                        mode="reciprocal_rerank",
+                        use_async=False,
+                    )
+                    print("✅ Hybrid retriever initialized successfully")
+
+                except Exception as e:
+                    print(f"❌ Warning: Could not initialize BM25 retriever: {e}")
+                    print("🔄 Falling back to vector-only retrieval")
+                    hybrid_retriever = vector_retriever
+
+    except Exception as e:
+        print(f"❌ Warning: Could not initialize retrievers: {e}")
+        hybrid_retriever = None
+        vector_retriever = None
+        bm25_retriever = None
+else:
+    print("❌ Warning: Could not initialize retrievers - index is None")
+
+def call_groq_api(prompt):
+    """Call Groq API instead of LM Studio"""
+    try:
+        response = Settings.llm.complete(prompt)
+        return str(response)
+    except Exception as e:
+        print(f"❌ Groq API call failed: {e}")
+        raise e
+
+def get_direct_answer(question, symptom_summary, conversation_context="", max_context_nodes=8, is_risk_assessment=True):
+    """Get answer using hybrid retriever with retrieved context"""
+
+    print(f"🎯 Processing question: {question}")
+
+    if not hybrid_retriever:
+        return "Error: Retriever not available. Please check if documents are properly loaded in the index."
+
+    try:
+        print("🔍 Retrieving with available retrieval method...")
+        retrieved_nodes = hybrid_retriever.retrieve(question)
+        print(f"📊 Retrieved {len(retrieved_nodes)} nodes")
+
+    except Exception as e:
+        print(f"❌ Retrieval failed: {e}")
+        return f"Error during document retrieval: {e}. Please check your document index."
+
+    if not retrieved_nodes:
+        return "No relevant documents found for this question. Please ensure your medical knowledge base is properly loaded and consult your healthcare provider for medical advice."
+
+    try:
+        reranker = SentenceTransformerRerank(
+            model='cross-encoder/ms-marco-MiniLM-L-2-v2',
+            top_n=max_context_nodes,
+        )
+
+        reranked_nodes = reranker.postprocess_nodes(retrieved_nodes, query_str=question)
+        print(f"🎯 After reranking: {len(reranked_nodes)} nodes")
+
+    except Exception as e:
+        print(f"❌ Reranking failed: {e}, using original nodes")
+        reranked_nodes = retrieved_nodes[:max_context_nodes]
+
+    filtered_nodes = []
+    pregnancy_keywords = ['pregnancy', 'preeclampsia', 'gestational', 'trimester', 'fetal', 'bleeding', 'contractions', 'prenatal']
+
+    for node in reranked_nodes:
+        node_text = node.get_text().lower()
+        if any(keyword in node_text for keyword in pregnancy_keywords):
+            filtered_nodes.append(node)
+
+    if filtered_nodes:
+        reranked_nodes = filtered_nodes[:max_context_nodes]
+        print(f"🔍 After pregnancy keyword filtering: {len(reranked_nodes)} nodes")
+    else:
+        print("⚠️ No pregnancy-related content found, using original nodes")
+
+    context_chunks = []
+    total_chars = 0
+    max_context_chars = 6000
+
+    for node in reranked_nodes:
+        node_text = node.get_text()
+        if total_chars + len(node_text) <= max_context_chars:
+            context_chunks.append(node_text)
+            total_chars += len(node_text)
+        else:
+            remaining_chars = max_context_chars - total_chars
+            if remaining_chars > 100:
+                context_chunks.append(node_text[:remaining_chars] + "...")
+            break
+
+    context_text = "\n\n---\n\n".join(context_chunks)
+
+    if is_risk_assessment:
+        prompt = f"""You are the GraviLog Pregnancy Risk Assessment Agent. Use ONLY the context below - do not invent or add any new medical facts.
+
+SYMPTOM RESPONSES:
+{symptom_summary}
+
+MEDICAL KNOWLEDGE:
+{context_text}
+
+Respond ONLY in this exact format (no extra text):
+
+🏥 Risk Assessment Complete
+**Risk Level:** <Low/Medium/High>
+**Recommended Action:** <from KB's Risk Output Labels>
+
+🔬 Rationale:
+<One or two sentences citing which bullet(s) from the KB triggered your risk level.>"""
+
+    else:
+        prompt = f"""You are a pregnancy health assistant. Based on the medical knowledge below, answer the user's question about pregnancy symptoms and conditions.
+
+USER QUESTION: {question}
+
+CONVERSATION CONTEXT:
+{conversation_context}
+
+CURRENT SYMPTOMS REPORTED:
+{symptom_summary}
+
+MEDICAL KNOWLEDGE:
+{context_text}
+
+Provide a clear, informative answer based on the medical knowledge. Always mention if symptoms require medical attention and provide risk level (Low/Medium/High) when relevant."""
+
+    try:
+        print("🤖 Generating response with Groq API...")
+        response_text = call_groq_api(prompt)
+        return response_text
+
+    except Exception as e:
+        print(f"❌ LLM response failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return f"Error generating response: {e}"
+
+def get_answer_with_query_engine(question):
+    """Alternative approach using LlamaIndex query engine with hybrid retrieval"""
+    try:
+        print(f"🎯 Processing question with query engine: {question}")
+
+        if index is None:
+            return "Error: Could not load index"
+
+        if hybrid_retriever:
+            query_engine = RetrieverQueryEngine.from_args(
+                retriever=hybrid_retriever,
+                response_synthesizer=get_response_synthesizer(
+                    response_mode="compact",
+                    use_async=False
+                ),
+                node_postprocessors=[
+                    SentenceTransformerRerank(
+                        model='cross-encoder/ms-marco-MiniLM-L-2-v2',
+                        top_n=5
+                    )
+                ]
+            )
+        else:
+            query_engine = index.as_query_engine(
+                similarity_top_k=10,
+                response_mode="compact"
+            )
+
+        print("🤖 Querying with engine...")
+        response = query_engine.query(question)
+
+        return str(response)
+
+    except Exception as e:
+        print(f"❌ Query engine failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return f"Error with query engine: {e}. Please check your setup and try again."
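For reference, the two entry points above can also be called directly. A minimal sketch, not part of the commit, assuming the Pinecone index is already populated; the symptom summary string is an invented example.

# Hypothetical direct use of the retrieval helpers.
from rag_functions import get_direct_answer, get_answer_with_query_engine

symptoms = "Are you currently experiencing any unusual bleeding or discharge?: No"

# Free-form Q&A path: hybrid retrieve -> rerank -> keyword filter -> Groq prompt.
print(get_direct_answer(
    "What are warning signs of preeclampsia?",
    symptoms,
    is_risk_assessment=False,
))

# Fallback path that app.py uses when the direct answer looks like an error.
print(get_answer_with_query_engine("What are warning signs of preeclampsia?"))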
requirements.txt
ADDED
Binary file (8.69 kB)
utils.py
ADDED
@@ -0,0 +1,184 @@
+import os
+from dotenv import load_dotenv
+from pinecone import Pinecone, ServerlessSpec
+from llama_index.core import (SimpleDirectoryReader, Document, VectorStoreIndex, StorageContext, load_index_from_storage)
+from llama_index.core.node_parser import SemanticSplitterNodeParser
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.readers.file import CSVReader
+from llama_index.vector_stores.pinecone import PineconeVectorStore
+from llama_index.core.settings import Settings
+from llama_index.llms.groq import Groq
+
+
+
+load_dotenv()
+
+
+embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
+llm = Groq(
+    model="llama-3.1-8b-instant",
+    api_key=os.getenv("GROQ_API_KEY"),
+    max_tokens=500,
+    temperature=0.1
+)
+
+
+Settings.embed_model = embed_model
+Settings.llm = llm
+
+
+pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+index_name = os.getenv("PINECONE_INDEX")
+
+def get_vector_store():
+    pinecone_index = pc.Index(index_name)
+    return PineconeVectorStore(pinecone_index=pinecone_index)
+
+def get_storage_context(for_rebuild=False):
+    vector_store = get_vector_store()
+    persist_dir = "./storage"
+
+    if for_rebuild or not os.path.exists(persist_dir):
+        return StorageContext.from_defaults(vector_store=vector_store)
+    else:
+        return StorageContext.from_defaults(
+            vector_store=vector_store,
+            persist_dir=persist_dir
+        )
+
+
+
+
+
+def get_and_chunk_documents():
+    try:
+        file_extractor = {".csv": CSVReader()}
+
+        documents = SimpleDirectoryReader(
+            "../knowledge_base",
+            file_extractor=file_extractor
+        ).load_data()
+
+        print(f"📄 Loaded {len(documents)} documents")
+
+        node_parser = SemanticSplitterNodeParser(
+            buffer_size=1,
+            breakpoint_percentile_threshold=95,
+            embed_model=embed_model
+        )
+
+        nodes = node_parser.get_nodes_from_documents(documents)
+        print(f"📄 Created {len(nodes)} document chunks")
+        return nodes
+
+    except Exception as e:
+        print(f"❌ Error loading documents: {e}")
+        return []
+
+
+def get_index():
+    try:
+        storage_context = get_storage_context()
+        return load_index_from_storage(storage_context)
+    except Exception as e:
+        print(f"⚠️ Local storage not found, creating index from existing Pinecone data...")
+        try:
+            vector_store = get_vector_store()
+            storage_context = get_storage_context()
+            index = VectorStoreIndex.from_vector_store(
+                vector_store=vector_store,
+                storage_context=storage_context
+            )
+            return index
+        except Exception as e2:
+            print(f"❌ Error creating index from vector store: {e2}")
+            return None
+
+def check_index_status():
+    try:
+        pinecone_index = pc.Index(index_name)
+        stats = pinecone_index.describe_index_stats()
+        vector_count = stats.get('total_vector_count', 0)
+
+        if vector_count > 0:
+            print(f"✅ Index found with {vector_count} vectors")
+            return True
+        else:
+            print("❌ Index exists but is empty")
+            return False
+    except Exception as e:
+        print(f"❌ Error checking index: {e}")
+        return False
+
+
+
+def clear_pinecone_index():
+    """Delete all vectors from Pinecone index"""
+    try:
+        pinecone_index = pc.Index(index_name)
+
+        stats = pinecone_index.describe_index_stats()
+        vector_count = stats.get('total_vector_count', 0)
+        print(f"🗑️ Current vectors in index: {vector_count}")
+
+        if vector_count > 0:
+            pinecone_index.delete(delete_all=True)
+            print("✅ All vectors deleted from Pinecone index")
+        else:
+            print("ℹ️ Index is already empty")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error clearing index: {e}")
+        return False
+
+def rebuild_index():
+    """Clear old data and rebuild index with new CSV processing"""
+    try:
+        print("🔄 Starting index rebuild process...")
+
+        if not clear_pinecone_index():
+            print("❌ Failed to clear index, aborting rebuild")
+            return None
+
+        import shutil
+        if os.path.exists("./storage"):
+            shutil.rmtree("./storage")
+            print("🗑️ Cleared local storage")
+
+        nodes = get_and_chunk_documents()
+
+        if not nodes:
+            print("❌ No nodes created, cannot rebuild index")
+            return None
+
+        storage_context = get_storage_context(for_rebuild=True)
+        index = VectorStoreIndex(nodes, storage_context=storage_context)
+
+        index.storage_context.persist(persist_dir="./storage")
+
+        print(f"✅ Index rebuilt successfully with {len(nodes)} nodes")
+        return index
+
+    except Exception as e:
+        print(f"❌ Error rebuilding index: {e}")
+        return None