Affanp committed on
Commit
44a2e1d
·
0 Parent(s):

Initial commit - Pregnancy RAG Chatbot

Files changed (6)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +477 -0
  4. rag_functions.py +246 -0
  5. requirements.txt +0 -0
  6. utils.py +184 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Pregnancy RAG Chatbot
+ emoji: 📈
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.35.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: Pregnancy Risk Assessment AI Chatbot
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,477 @@
+ import gradio as gr
+ import os
+ import sys
+ from datetime import datetime
+ import traceback
+
+
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+
+ from rag_functions import get_direct_answer, get_answer_with_query_engine
+ from utils import get_index
+ print("✅ Successfully imported RAG functions")
+
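+ # PregnancyRiskAgent is a simple linear state machine: it asks the scripted
+ # symptom questions one at a time, stores each answer in current_symptoms,
+ # produces one RAG-backed risk assessment, then switches to free-form Q&A.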
+ class PregnancyRiskAgent:
+     def __init__(self):
+         self.conversation_history = []
+         self.current_symptoms = {}
+         self.risk_assessment_done = False
+         self.user_context = {}
+         self.last_user_query = ""
+
+         self.symptom_questions = [
+             "Are you currently experiencing any unusual bleeding or discharge?",
+             "How would you describe your baby's movements today compared to yesterday?",
+             "Have you had any headaches that won't go away or that affect your vision?",
+             "Do you feel any pressure or pain in your pelvis or lower back?",
+             "Are you experiencing any other symptoms? (If yes, please describe briefly)"
+         ]
+
+         self.current_question_index = 0
+         self.waiting_for_first_response = True
+
+     def add_to_conversation_history(self, role, message):
+         self.conversation_history.append({
+             "role": role,
+             "message": message,
+             "timestamp": datetime.now().isoformat()
+         })
+
+         if len(self.conversation_history) > 20:
+             self.conversation_history = self.conversation_history[-20:]
+
+     def get_conversation_context(self):
+         context_parts = []
+
+         recent_history = self.conversation_history[-10:]
+
+         for entry in recent_history:
+             if entry["role"] == "user":
+                 context_parts.append(f"User: {entry['message']}")
+             else:
+                 context_parts.append(f"Assistant: {entry['message'][:200]}...")
+
+         return "\n".join(context_parts)
+
+     def is_follow_up_question(self, user_input):
+         follow_up_indicators = [
+             "what about", "can you explain", "what does", "why", "how",
+             "tell me more", "what should i", "is it normal", "should i be worried",
+             "what if", "when should", "how long", "what causes", "is this"
+         ]
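+         # Simple keyword heuristic: e.g. "what if the bleeding gets worse?"
+         # contains "what if", so it is treated as a follow-up question.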
+         user_lower = user_input.lower()
+         return any(indicator in user_lower for indicator in follow_up_indicators)
+
+     def process_user_input(self, user_input, chat_history):
+         try:
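+             # Three phases: (1) the first answer to the welcome question,
+             # (2) the remaining scripted symptom questions, and (3) free-form
+             # follow-up Q&A once the risk assessment has been delivered.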
+             self.last_user_query = user_input
+             self.add_to_conversation_history("user", user_input)
+
+             if self.waiting_for_first_response:
+                 self.current_symptoms["question_0"] = user_input
+                 self.waiting_for_first_response = False
+                 self.current_question_index = 1
+
+                 if self.current_question_index < len(self.symptom_questions):
+                     bot_response = self.symptom_questions[self.current_question_index]
+                 else:
+                     bot_response = self.provide_risk_assessment()
+                     self.risk_assessment_done = True
+
+                 self.add_to_conversation_history("assistant", bot_response)
+                 return bot_response
+
+             elif self.current_question_index < len(self.symptom_questions) and not self.risk_assessment_done:
+                 self.current_symptoms[f"question_{self.current_question_index}"] = user_input
+                 self.current_question_index += 1
+
+                 if self.current_question_index < len(self.symptom_questions):
+                     bot_response = self.symptom_questions[self.current_question_index]
+                 else:
+                     bot_response = self.provide_risk_assessment()
+                     self.risk_assessment_done = True
+
+                 self.add_to_conversation_history("assistant", bot_response)
+                 return bot_response
+
+             else:
+                 bot_response = self.handle_follow_up_conversation(user_input)
+                 self.add_to_conversation_history("assistant", bot_response)
+                 return bot_response
+
+         except Exception as e:
+             print(f"❌ Error in process_user_input: {e}")
+             traceback.print_exc()
+             error_response = "I encountered an error. Please try again or consult your healthcare provider."
+             self.add_to_conversation_history("assistant", error_response)
+             return error_response
+
+     def handle_follow_up_conversation(self, user_input):
+         try:
+             print(f"🔍 Processing follow-up question: {user_input}")
+
+             symptom_summary = self.create_symptom_summary()
+             conversation_context = self.get_conversation_context()
+
+             if any(word in user_input.lower() for word in ["last", "previous", "what did i ask", "my question"]):
+                 # last_user_query has already been overwritten with the current
+                 # input, so recover the previous user turn from the history.
+                 prior_user_messages = [e["message"] for e in self.conversation_history[:-1] if e["role"] == "user"]
+                 if prior_user_messages:
+                     return f"Your last question was: \"{prior_user_messages[-1]}\"\n\nWould you like me to elaborate on that topic or do you have a different question?"
+                 else:
+                     return "I don't have a record of your previous question. Could you please rephrase what you'd like to know?"
+
+             rag_response = get_direct_answer(user_input, symptom_summary, conversation_context=conversation_context, is_risk_assessment=False)
+
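+             # Heuristic fallback: an error marker or a very short answer
+             # triggers a retry through the LlamaIndex query engine.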
+             if "Error" in rag_response or len(rag_response) < 50:
+                 print("🔄 Trying alternative method...")
+                 rag_response = get_answer_with_query_engine(user_input)
+
+             bot_response = f"""Based on your symptoms and medical literature:
+
+ {rag_response}"""
+
+             return bot_response
+
+         except Exception as e:
+             print(f"❌ Error in follow-up conversation: {e}")
+             return "I encountered an error processing your question. Could you please rephrase it or consult your healthcare provider?"
+
+     def create_symptom_summary(self):
+         if not self.current_symptoms:
+             return "No specific symptoms reported yet"
+
+         summary_parts = []
+         for i, (key, response) in enumerate(self.current_symptoms.items()):
+             if i < len(self.symptom_questions):
+                 question = self.symptom_questions[i]
+                 summary_parts.append(f"{question}: {response}")
+         return "\n".join(summary_parts)
+
+     def parse_risk_level(self, text):
+         import re
+
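+         # The risk-assessment prompt asks the LLM to emit
+         # "**Risk Level:** <Low/Medium/High>"; the regex variants below
+         # tolerate formatting drift, e.g. (hypothetical output):
+         #   parse_risk_level("**Risk Level:** High") -> "High"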
+         patterns = [
+             r'\*\*Risk Level:\*\*\s*(Low|Medium|High)',
+             r'Risk Level:\s*\*\*(Low|Medium|High)\*\*',
+             r'Risk Level:\s*(Low|Medium|High)',
+             r'\*\*Risk Level:\*\*\s*<(Low|Medium|High)>',
+             r'Risk Level.*?<(Low|Medium|High)>',
+         ]
+
+         for pattern in patterns:
+             match = re.search(pattern, text, re.IGNORECASE)
+             if match:
+                 risk_level = match.group(1).capitalize()
+                 print(f"✅ Successfully parsed risk level: {risk_level}")
+                 return risk_level
+
+         print(f"❌ Could not parse risk level from: {text[:200]}...")
+         return None
+
+     def provide_risk_assessment(self):
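+         # Summarize the collected answers, ask the RAG pipeline for a
+         # structured assessment, then parse the risk level out of the
+         # formatted LLM response, falling back to Medium if parsing fails.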
+         all_symptoms = self.create_symptom_summary()
+
+         rag_query = f"Analyze these pregnancy symptoms for risk assessment:\n{all_symptoms}\n\nProvide risk level and medical recommendations."
+         detailed_analysis = get_direct_answer(rag_query, all_symptoms, is_risk_assessment=True)
+
+         print(f"🔍 RAG Response: {detailed_analysis[:300]}...")
+
+         llm_risk_level = self.parse_risk_level(detailed_analysis)
+
+         if llm_risk_level:
+             risk_level = llm_risk_level
+
+             if risk_level == "Low":
+                 action = "✅ Continue routine prenatal care and self-monitoring"
+             elif risk_level == "Medium":
+                 action = "⚠️ Contact your doctor within 24 hours"
+             elif risk_level == "High":
+                 action = "🚨 Immediate visit to ER or OB emergency care required"
+         else:
+             print("⚠️ RAG assessment failed, using fallback")
+             risk_level = "Medium"
+             action = "⚠️ Contact your doctor within 24 hours"
+
+         symptom_list = []
+         for i, (key, symptom) in enumerate(self.current_symptoms.items()):
+             question = self.symptom_questions[i] if i < len(self.symptom_questions) else f"Question {i+1}"
+             symptom_list.append(f"• **{question}**: {symptom}")
+
+         assessment = f"""
+ ## 🏥 **Risk Assessment Complete**
+
+ **Risk Level: {risk_level}**
+ **Recommended Action: {action}**
+
+ ### 📋 **Your Reported Symptoms:**
+ {chr(10).join(symptom_list)}
+
+ ### 🔬 **Medical Analysis:**
+ {detailed_analysis}
+
+ ### 💡 **Next Steps:**
+ - Follow the recommended action above
+ - Keep monitoring your symptoms
+ - Contact your healthcare provider if symptoms worsen
+ - Feel free to ask me any follow-up questions about pregnancy health
+
+ """
+         return assessment
+
+     def reset_conversation(self):
+         self.conversation_history = []
+         self.current_symptoms = {}
+         self.current_question_index = 0
+         self.risk_assessment_done = False
+         self.waiting_for_first_response = True
+         self.user_context = {}
+         self.last_user_query = ""
+         return get_welcome_message()
+
+ def get_welcome_message():
+     return """Hello! I'm here to help assess pregnancy-related symptoms and provide risk insights based on medical literature.
+
+ I'll ask you a few important questions about your current symptoms, then provide a risk assessment and recommendations. After that, feel free to ask any follow-up questions!
+
+ **To get started, please tell me:**
+ Are you currently experiencing any unusual bleeding or discharge?
+
+ ---
+ ⚠️ **Important**: This tool is for informational purposes only and should not replace professional medical care. In case of emergency, contact your healthcare provider immediately."""
+
+
+ def create_new_agent():
+     return PregnancyRiskAgent()
+
+
+ agent = create_new_agent()
+
+ def chat_interface_with_reset(user_input, history):
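+     # Typing "reset", "restart", or "new assessment" recreates the
+     # module-level agent, restarting the scripted questionnaire.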
+     global agent
+
+     if user_input.lower() in ["reset", "restart", "new assessment"]:
+         agent = create_new_agent()
+         return get_welcome_message()
+
+     response = agent.process_user_input(user_input, history)
+     return response
+
+ def reset_chat():
+     global agent
+     agent = create_new_agent()
+     return [{"role": "assistant", "content": get_welcome_message()}], ""
+
+
+ custom_css = """
+ body, .gradio-container {
+     color: yellow !important;
+ }
+
+ .header {
+     background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 100%);
+     padding: 2rem;
+     border-radius: 1rem;
+     text-align: center;
+     margin-bottom: 2rem;
+     box-shadow: 0 4px 12px rgba(0,0,0,0.1);
+ }
+
+ .header h1 {
+     color: black !important;
+     margin-bottom: 0.5rem;
+     font-size: 2.5rem;
+ }
+
+ .header p {
+     color: black !important;
+     font-size: 1.1rem;
+     margin: 0.5rem 0;
+ }
+
+ .warning {
+     background-color: #fff4e6;
+     border-left: 6px solid #ff7f00;
+     padding: 15px;
+     border-radius: 5px;
+     margin: 10px 0;
+ }
+
+ .warning h3 {
+     color: black !important;
+     margin-top: 0;
+ }
+
+ .warning p {
+     color: black !important;
+     line-height: 1.6;
+ }
+
+ div[style*="background-color: #e8f5e8"] {
+     color: black !important;
+ }
+
+ div[style*="background-color: #e8f5e8"] h3 {
+     color: black !important;
+ }
+
+ div[style*="background-color: #e8f5e8"] li {
+     color: black !important;
+ }
+
+ .chatbot {
+     color: black !important;
+ }
+
+ .message {
+     color: black !important;
+ }
+
+ /* Hide Gradio footer elements */
+ .footer {
+     display: none !important;
+ }
+
+ .gradio-container .footer {
+     display: none !important;
+ }
+
+ footer {
+     display: none !important;
+ }
+
+ .api-docs {
+     display: none !important;
+ }
+
+ .built-with {
+     display: none !important;
+ }
+
+ .gradio-container > .built-with {
+     display: none !important;
+ }
+
+ .settings {
+     display: none !important;
+ }
+
+ div[class*="footer"] {
+     display: none !important;
+ }
+
+ div[class*="built"] {
+     display: none !important;
+ }
+
+ /* Note: ':contains()' is not a valid CSS selector, so rules such as
+    *:contains("Built with Gradio") are ignored by browsers; the class and
+    attribute selectors above do the footer hiding instead. */
+ """
+
+
+ with gr.Blocks(css=custom_css) as demo:
+     gr.HTML("""
+     <div class="header">
+         <h1>🤱 Pregnancy RAG Chatbot</h1>
+         <p><strong style="color: black !important;">Proactive RAG-powered pregnancy risk management</strong></p>
+     </div>
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.HTML("""
+             <div class="warning">
+                 <h3>⚠️ Medical Disclaimer</h3>
+                 <p>This AI assistant provides information based on medical literature but is NOT a substitute for professional medical advice, diagnosis, or treatment.</p>
+                 <p><strong style="color: black !important;">In emergencies, call emergency services immediately.</strong></p>
+             </div>
+             """)
+
+     chatbot = gr.ChatInterface(
+         fn=chat_interface_with_reset,
+         chatbot=gr.Chatbot(
+             value=[{"role": "assistant", "content": get_welcome_message()}],
+             show_label=False,
+             type='messages'
+         ),
+         textbox=gr.Textbox(
+             placeholder="Type your response here...",
+             show_label=False,
+             max_length=1000,
+             submit_btn=True
+         )
+     )
+
+     with gr.Row():
+         reset_btn = gr.Button("🔄 Start New Assessment", variant="secondary")
+
+     reset_btn.click(
+         fn=reset_chat,
+         outputs=[chatbot.chatbot, chatbot.textbox],
+         show_progress=False
+     )
+
+
+ def check_groq_connection():
+     try:
+         from utils import llm  # utils.py lives at the repo root; there is no backend package here
+         test_response = llm.complete("Hello")
+         print("✅ Groq connection successful")
+         return True
+     except Exception as e:
+         print(f"❌ Groq connection failed: {e}")
+         return False
+
+
+ def refresh_page():
+     """Force a complete page refresh"""
+     return None
+
+
+ if __name__ == "__main__":
+     print("🚀 Starting GraviLog Pregnancy Risk Assessment Agent...")
+     check_groq_connection()
+
+     is_hf_space = os.getenv('SPACE_ID') is not None
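+     # SPACE_ID is set automatically by the Hugging Face Spaces runtime,
+     # so its presence distinguishes a Spaces deployment from a local run.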
+     if is_hf_space:
+         print("📝 Running on Hugging Face Spaces")
+         print("📝 Each page refresh will start a new conversation")
+         demo.queue().launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             share=False,
+             debug=False
+         )
+     else:
+         print("📝 Running locally")
+         print("📝 Using Groq API for LLM processing")
+         print("📝 Make sure your GROQ_API_KEY is set in environment variables")
+         print("📝 Make sure your Pinecone index is set up and populated")
+
+         demo.queue().launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             share=True,
+             debug=True,
+             show_error=True
+         )
rag_functions.py ADDED
@@ -0,0 +1,246 @@
+ import os
+ import Stemmer
+ import requests
+ from utils import get_and_chunk_documents, llm, embed_model, get_index
+ from llama_index.retrievers.bm25 import BM25Retriever
+ from llama_index.core.postprocessor import SentenceTransformerRerank
+ from llama_index.core.query_engine import RetrieverQueryEngine
+ from llama_index.core.response_synthesizers import get_response_synthesizer
+ from llama_index.core.settings import Settings
+ from llama_index.core import VectorStoreIndex
+ from llama_index.core.llms import ChatMessage
+ from llama_index.core.retrievers import QueryFusionRetriever
+ import json
+
+
+ Settings.llm = llm
+ Settings.embed_model = embed_model
+
+
+ index = get_index()
+ hybrid_retriever = None
+ vector_retriever = None
+ bm25_retriever = None
+
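+ # Retrievers are built once at import time: dense vector retrieval always,
+ # plus BM25 keyword retrieval when the docstore actually holds text; the
+ # two are fused into one hybrid retriever via reciprocal-rank fusion.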
+ if index:
+     try:
+         vector_retriever = index.as_retriever(similarity_top_k=15)
+         print("✅ Vector retriever initialized successfully")
+
+         all_nodes = index.docstore.docs
+         if len(all_nodes) == 0:
+             print("⚠️ Warning: No documents found in index, skipping BM25 retriever")
+             hybrid_retriever = vector_retriever
+         else:
+             has_text_content = False
+             for node_id, node in all_nodes.items():
+                 if hasattr(node, 'text') and node.text and node.text.strip():
+                     has_text_content = True
+                     break
+
+             if not has_text_content:
+                 print("⚠️ Warning: No text content found in documents, skipping BM25 retriever")
+                 hybrid_retriever = vector_retriever
+             else:
+                 try:
+                     print("🔄 Creating BM25 retriever...")
+                     bm25_retriever = BM25Retriever.from_defaults(
+                         docstore=index.docstore,
+                         similarity_top_k=15,
+                         verbose=False
+                     )
+                     print("✅ BM25 retriever initialized successfully")
+
+                     hybrid_retriever = QueryFusionRetriever(
+                         retrievers=[vector_retriever, bm25_retriever],
+                         similarity_top_k=20,
+                         num_queries=1,
+                         mode="reciprocal_rerank",
+                         use_async=False,
+                     )
+                     print("✅ Hybrid retriever initialized successfully")
+
+                 except Exception as e:
+                     print(f"❌ Warning: Could not initialize BM25 retriever: {e}")
+                     print("🔄 Falling back to vector-only retrieval")
+                     hybrid_retriever = vector_retriever
+
+     except Exception as e:
+         print(f"❌ Warning: Could not initialize retrievers: {e}")
+         hybrid_retriever = None
+         vector_retriever = None
+         bm25_retriever = None
+ else:
+     print("❌ Warning: Could not initialize retrievers - index is None")
+
+ def call_groq_api(prompt):
+     """Call the Groq API (used here instead of a local LM Studio server)"""
+     try:
+         response = Settings.llm.complete(prompt)
+         return str(response)
+     except Exception as e:
+         print(f"❌ Groq API call failed: {e}")
+         raise
+
+ def get_direct_answer(question, symptom_summary, conversation_context="", max_context_nodes=8, is_risk_assessment=True):
+     """Get answer using hybrid retriever with retrieved context"""
+
+     print(f"🎯 Processing question: {question}")
+
+     if not hybrid_retriever:
+         return "Error: Retriever not available. Please check if documents are properly loaded in the index."
+
+     try:
+         print("🔍 Retrieving with available retrieval method...")
+         retrieved_nodes = hybrid_retriever.retrieve(question)
+         print(f"📊 Retrieved {len(retrieved_nodes)} nodes")
+
+     except Exception as e:
+         print(f"❌ Retrieval failed: {e}")
+         return f"Error during document retrieval: {e}. Please check your document index."
+
+     if not retrieved_nodes:
+         return "No relevant documents found for this question. Please ensure your medical knowledge base is properly loaded and consult your healthcare provider for medical advice."
+
+     try:
+         reranker = SentenceTransformerRerank(
+             model='cross-encoder/ms-marco-MiniLM-L-2-v2',
+             top_n=max_context_nodes,
+         )
+
+         reranked_nodes = reranker.postprocess_nodes(retrieved_nodes, query_str=question)
+         print(f"🎯 After reranking: {len(reranked_nodes)} nodes")
+
+     except Exception as e:
+         print(f"❌ Reranking failed: {e}, using original nodes")
+         reranked_nodes = retrieved_nodes[:max_context_nodes]
+
+     filtered_nodes = []
+     pregnancy_keywords = ['pregnancy', 'preeclampsia', 'gestational', 'trimester', 'fetal', 'bleeding', 'contractions', 'prenatal']
+
+     for node in reranked_nodes:
+         node_text = node.get_text().lower()
+         if any(keyword in node_text for keyword in pregnancy_keywords):
+             filtered_nodes.append(node)
+
+     if filtered_nodes:
+         reranked_nodes = filtered_nodes[:max_context_nodes]
+         print(f"🔍 After pregnancy keyword filtering: {len(reranked_nodes)} nodes")
+     else:
+         print("⚠️ No pregnancy-related content found, using original nodes")
+
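+     # Greedily pack the reranked chunks into a ~6,000-character context
+     # budget, truncating the final chunk rather than dropping it outright.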
+     context_chunks = []
+     total_chars = 0
+     max_context_chars = 6000
+
+     for node in reranked_nodes:
+         node_text = node.get_text()
+         if total_chars + len(node_text) <= max_context_chars:
+             context_chunks.append(node_text)
+             total_chars += len(node_text)
+         else:
+             remaining_chars = max_context_chars - total_chars
+             if remaining_chars > 100:
+                 context_chunks.append(node_text[:remaining_chars] + "...")
+             break
+
+     context_text = "\n\n---\n\n".join(context_chunks)
+
+     if is_risk_assessment:
+         prompt = f"""You are the GraviLog Pregnancy Risk Assessment Agent. Use ONLY the context below; do not invent or add any new medical facts.
+
+ SYMPTOM RESPONSES:
+ {symptom_summary}
+
+ MEDICAL KNOWLEDGE:
+ {context_text}
+
+ Respond ONLY in this exact format (no extra text):
+
+ 🏥 Risk Assessment Complete
+ **Risk Level:** <Low/Medium/High>
+ **Recommended Action:** <from KB's Risk Output Labels>
+
+ 🔬 Rationale:
+ <One or two sentences citing which bullet(s) from the KB triggered your risk level.>"""
+
+     else:
+         prompt = f"""You are a pregnancy health assistant. Based on the medical knowledge below, answer the user's question about pregnancy symptoms and conditions.
+
+ USER QUESTION: {question}
+
+ CONVERSATION CONTEXT:
+ {conversation_context}
+
+ CURRENT SYMPTOMS REPORTED:
+ {symptom_summary}
+
+ MEDICAL KNOWLEDGE:
+ {context_text}
+
+ Provide a clear, informative answer based on the medical knowledge. Always mention if symptoms require medical attention, and provide a risk level (Low/Medium/High) when relevant."""
+
+     try:
+         print("🤖 Generating response with Groq API...")
+         response_text = call_groq_api(prompt)
+         return response_text
+
+     except Exception as e:
+         print(f"❌ LLM response failed: {e}")
+         import traceback
+         traceback.print_exc()
+         return f"Error generating response: {e}"
+
+ def get_answer_with_query_engine(question):
+     """Alternative approach using a LlamaIndex query engine with hybrid retrieval"""
+     try:
+         print(f"🎯 Processing question with query engine: {question}")
+
+         if index is None:
+             return "Error: Could not load index"
+
+         if hybrid_retriever:
+             query_engine = RetrieverQueryEngine.from_args(
+                 retriever=hybrid_retriever,
+                 response_synthesizer=get_response_synthesizer(
+                     response_mode="compact",
+                     use_async=False
+                 ),
+                 node_postprocessors=[
+                     SentenceTransformerRerank(
+                         model='cross-encoder/ms-marco-MiniLM-L-2-v2',
+                         top_n=5
+                     )
+                 ]
+             )
+         else:
+             query_engine = index.as_query_engine(
+                 similarity_top_k=10,
+                 response_mode="compact"
+             )
+
+         print("🤖 Querying with engine...")
+         response = query_engine.query(question)
+
+         return str(response)
+
+     except Exception as e:
+         print(f"❌ Query engine failed: {e}")
+         import traceback
+         traceback.print_exc()
+         return f"Error with query engine: {e}. Please check your setup and try again."
requirements.txt ADDED
Binary file (8.69 kB). View file
 
utils.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ from dotenv import load_dotenv
+ from pinecone import Pinecone, ServerlessSpec
+ from llama_index.core import (SimpleDirectoryReader, Document, VectorStoreIndex, StorageContext, load_index_from_storage)
+ from llama_index.core.node_parser import SemanticSplitterNodeParser
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.readers.file import CSVReader
+ from llama_index.vector_stores.pinecone import PineconeVectorStore
+ from llama_index.core.settings import Settings
+ from llama_index.llms.groq import Groq
+
+
+ load_dotenv()
+
+
+ embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ llm = Groq(
+     model="llama-3.1-8b-instant",
+     api_key=os.getenv("GROQ_API_KEY"),
+     max_tokens=500,
+     temperature=0.1
+ )
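+
+ # NOTE: the embedding model, Groq model, and generation settings above are
+ # this repo's own defaults; GROQ_API_KEY (and the Pinecone variables below)
+ # must be set in the environment, loaded from .env via load_dotenv().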
+
+ Settings.embed_model = embed_model
+ Settings.llm = llm
+
+
+ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+ index_name = os.getenv("PINECONE_INDEX")
+
+ def get_vector_store():
+     pinecone_index = pc.Index(index_name)
+     return PineconeVectorStore(pinecone_index=pinecone_index)
+
+ def get_storage_context(for_rebuild=False):
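+     # for_rebuild=True skips any persisted local metadata so the storage
+     # context is built purely from the Pinecone vector store.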
+     vector_store = get_vector_store()
+     persist_dir = "./storage"
+
+     if for_rebuild or not os.path.exists(persist_dir):
+         return StorageContext.from_defaults(vector_store=vector_store)
+     else:
+         return StorageContext.from_defaults(
+             vector_store=vector_store,
+             persist_dir=persist_dir
+         )
+
+
+ def get_and_chunk_documents():
+     try:
+         file_extractor = {".csv": CSVReader()}
+
+         # NOTE: this path assumes knowledge_base/ sits one level above this file
+         documents = SimpleDirectoryReader(
+             "../knowledge_base",
+             file_extractor=file_extractor
+         ).load_data()
+
+         print(f"📖 Loaded {len(documents)} documents")
+
+         node_parser = SemanticSplitterNodeParser(
+             buffer_size=1,
+             breakpoint_percentile_threshold=95,
+             embed_model=embed_model
+         )
+
+         nodes = node_parser.get_nodes_from_documents(documents)
+         print(f"📄 Created {len(nodes)} document chunks")
+         return nodes
+
+     except Exception as e:
+         print(f"❌ Error loading documents: {e}")
+         return []
+
+
+ def get_index():
+     try:
+         storage_context = get_storage_context()
+         return load_index_from_storage(storage_context)
+     except Exception:
+         print("⚠️ Local storage not found, creating index from existing Pinecone data...")
+         try:
+             vector_store = get_vector_store()
+             storage_context = get_storage_context()
+             index = VectorStoreIndex.from_vector_store(
+                 vector_store=vector_store,
+                 storage_context=storage_context
+             )
+             return index
+         except Exception as e2:
+             print(f"❌ Error creating index from vector store: {e2}")
+             return None
+
+ def check_index_status():
+     try:
+         pinecone_index = pc.Index(index_name)
+         stats = pinecone_index.describe_index_stats()
+         vector_count = stats.get('total_vector_count', 0)
+
+         if vector_count > 0:
+             print(f"✅ Index found with {vector_count} vectors")
+             return True
+         else:
+             print("❌ Index exists but is empty")
+             return False
+     except Exception as e:
+         print(f"❌ Error checking index: {e}")
+         return False
+
+
+ def clear_pinecone_index():
+     """Delete all vectors from the Pinecone index"""
+     try:
+         pinecone_index = pc.Index(index_name)
+
+         stats = pinecone_index.describe_index_stats()
+         vector_count = stats.get('total_vector_count', 0)
+         print(f"🗑️ Current vectors in index: {vector_count}")
+
+         if vector_count > 0:
+             pinecone_index.delete(delete_all=True)
+             print("✅ All vectors deleted from Pinecone index")
+         else:
+             print("ℹ️ Index is already empty")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Error clearing index: {e}")
+         return False
+
+ def rebuild_index():
+     """Clear old data and rebuild the index with new CSV processing"""
+     try:
+         print("🔄 Starting index rebuild process...")
+
+         if not clear_pinecone_index():
+             print("❌ Failed to clear index, aborting rebuild")
+             return None
+
+         import shutil
+         if os.path.exists("./storage"):
+             shutil.rmtree("./storage")
+             print("🗑️ Cleared local storage")
+
+         nodes = get_and_chunk_documents()
+
+         if not nodes:
+             print("❌ No nodes created, cannot rebuild index")
+             return None
+
+         storage_context = get_storage_context(for_rebuild=True)
+         index = VectorStoreIndex(nodes, storage_context=storage_context)
+
+         index.storage_context.persist(persist_dir="./storage")
+
+         print(f"✅ Index rebuilt successfully with {len(nodes)} nodes")
+         return index
+
+     except Exception as e:
+         print(f"❌ Error rebuilding index: {e}")
+         return None
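+
+ # Usage sketch (hypothetical; run manually when the knowledge base changes):
+ #   from utils import check_index_status, rebuild_index
+ #   if not check_index_status():
+ #       rebuild_index()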