Abhinav Gavireddi committed
Commit · 6c61722 · Parent(s): c613bb1

[fix]: optimized the entire pipeline

Files changed:
- app.py +222 -340
- requirements.txt +6 -5
- src/__init__.py +44 -20
- src/ghm.py +2 -2
- src/gpp.py +69 -152
- src/qa.py +75 -74
- src/retriever.py +48 -96
- src/utils.py +45 -11
app.py
CHANGED
@@ -3,250 +3,170 @@ import streamlit as st
 from datetime import datetime
 import re
 from werkzeug.utils import secure_filename

 from src.gpp import GPP, GPPConfig
 from src.qa import AnswerGenerator

-# Check if we need to modify the AnswerGenerator class to accept conversation context
-# If the original implementation doesn't support this, we'll create a wrapper
-
-class ContextAwareAnswerGenerator:
-    """Wrapper around AnswerGenerator to include conversation context"""
-
-    def __init__(self, chunks):
-        self.chunks = chunks
-        self.original_generator = AnswerGenerator(chunks)
-
-    def answer(self, question, conversation_context=None):
-        """
-        Generate answer with conversation context
-
-        Args:
-            chunks: Document chunks to search
-            question: Current question
-            conversation_context: List of previous Q&A for context
-
-        Returns:
-            answer, supporting_chunks
-        """
-        # If no conversation context or original implementation supports it directly
-        if conversation_context is None or len(conversation_context) <= 1:
-            return self.original_generator.answer(question)
-
-        # Otherwise, enhance the question with context
-        # Create a contextual prompt by summarizing previous exchanges
-        context_prompt = "Based on our conversation so far:\n"
-
-        # Include the last few exchanges (limiting to prevent context getting too large)
-        max_history = min(len(conversation_context) - 1, 4)  # Last 4 exchanges maximum
-        for i in range(max(0, len(conversation_context) - max_history - 1), len(conversation_context) - 1, 2):
-            if i < len(conversation_context) and i+1 < len(conversation_context):
-                user_q = conversation_context[i]["content"]
-                assistant_a = conversation_context[i+1]["content"]
-                context_prompt += f"You were asked: '{user_q}'\n"
-                context_prompt += f"You answered: '{assistant_a}'\n"
-
-        context_prompt += f"\nNow answer this follow-up question: {question}"
-
-        # Use the enhanced prompt
-        return self.original_generator.answer(context_prompt)
-
 # --- Page Configuration ---
 st.set_page_config(
-    page_title="Document Intelligence…
-    page_icon="…
     layout="wide"
 )

 # --- Session State Initialization ---
 if 'chat_history' not in st.session_state:
-    st.session_state.chat_history = []
-if 'parsed' not in st.session_state:
-    st.session_state.parsed = None
 if "selected_chunks" not in st.session_state:
     st.session_state.selected_chunks = []
-if "conversation_context" not in st.session_state:
-    st.session_state.conversation_context = []

-# --- Custom CSS for…
 st.markdown(
     """
     <style>
-    /* [old light-theme CSS block; most removed rules are truncated in the source. Recoverable
-       fragments include card styling (border-radius: 8px; padding: 20px; margin-bottom: 20px;
-       box-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24);), a chat container
-       (display: flex; flex-direction: column; gap: 12px; margin-top: 20px; margin-bottom: 20px;),
-       and message bubbles (border-radius: 18px; color: white; padding: 12px; font-size: 14px;)] */
-
-    .assistant-message .message-content {
-        background-color: #f0f2f6;
-        color: #1e1e1e;
-        border-bottom-left-radius: 4px;
-    }
-
-    /* Empty chat placeholder */
-    .empty-chat-placeholder {
-        display: flex;
-        flex-direction: column;
-        align-items: center;
-        justify-content: center;
-        height: 300px;
-        background-color: #f8f9fa;
-        border-radius: 8px;
-        margin-bottom: 20px;
-        text-align: center;
-        color: #6c757d;
-    }
-
-    .empty-chat-icon {
-        font-size: 40px;
-        margin-bottom: 16px;
-        color: #adb5bd;
-    }
-
-    /* Message typing indicator */
-    .typing-indicator {
-        display: flex;
-        align-items: center;
-        margin-top: 8px;
-    }
-
-    .typing-indicator span {
-        height: 8px;
-        width: 8px;
-        background-color: #4361ee;
-        border-radius: 50%;
-        margin: 0 2px;
-        display: inline-block;
-        opacity: 0.7;
-    }
-
-    .typing-indicator span:nth-child(1) { animation: pulse 1s infinite; }
-    .typing-indicator span:nth-child(2) { animation: pulse 1s infinite 0.2s; }
-    .typing-indicator span:nth-child(3) { animation: pulse 1s infinite 0.4s; }
-
-    @keyframes pulse {
-        0% { transform: scale(1); opacity: 0.7; }
-        50% { transform: scale(1.2); opacity: 1; }
-        100% { transform: scale(1); opacity: 0.7; }
-    }
-
-    /* Info box */
-    .stAlert {
-        border-radius: 8px;
-    }
     </style>
     """, unsafe_allow_html=True
 )
@@ -256,17 +176,10 @@ with st.sidebar:
     # App info section
     st.image("https://img.icons8.com/ios-filled/50/4A90E2/document.png", width=40)
     st.title("Document Intelligence")
-    st.caption(f"Last updated: {datetime.now().strftime('%Y-%m-%d')}")

     with st.expander("How It Works", expanded=True):
-        st.markdown(
-            """
-            1. **Upload PDF**: Select and parse your document
-            2. **Ask Questions**: Type your query about the document
-            3. **Get Answers**: AI analyzes and responds with insights
-            4. **View Evidence**: See supporting chunks in the right sidebar
-            """
-        )

     st.markdown("---")
@@ -275,54 +188,54 @@ with st.sidebar:
     uploaded_file = st.file_uploader("Select a PDF", type=["pdf"], help="Upload a PDF file to analyze")

     if uploaded_file:
-        [old upload-and-parse block; roughly two dozen removed lines are truncated in the source]
-            except Exception as e:
-                st.error(f"Parsing failed: {str(e)}")
-                st.session_state.parsed = None
-        with col2:
-            if st.button("Clear", use_container_width=True, key="clear_button"):
-                st.session_state.parsed = None
-                st.session_state.selected_chunks = []
-                st.session_state.chat_history = []
-                st.session_state.conversation_context = []
-                st.experimental_rerun()
-        except Exception as e:
-            st.error(f"Upload error: {str(e)}")

     # Display document preview if parsed
-    if st.session_state.parsed:
     st.markdown("---")
     st.subheader("Document Preview")
-    parsed = st.session_state.parsed

     # Layout PDF
     layout_pdf = parsed.get("layout_pdf")
     if layout_pdf and os.path.exists(layout_pdf):
         with st.expander("View Layout PDF", expanded=False):
             st.markdown(f"[Open in new tab]({layout_pdf})")

     # Content preview
     md_path = parsed.get("md_path")
@@ -335,123 +248,92 @@ with st.sidebar:
     except Exception as e:
         st.warning(f"Could not preview content: {str(e)}")

-    [several removed lines truncated in the source]
-    st.title("Document Q&A")
-    st.markdown("</div>", unsafe_allow_html=True)

 else:
-    #…
-    st.markdown("<div class='…
-    question = st.text_input(
-        "Ask a question about your document:",
-        key="question_input",
-        placeholder="E.g., 'What are the key findings?' or 'Summarize the data'",
-        on_change=None  # Ensure the input field gets cleared naturally after submission
-    )
-
-    col_btn1, col_btn2 = st.columns([4, 1])
-    with col_btn1:
-        submit_button = st.button("Get Answer", use_container_width=True)
-    with col_btn2:
-        clear_chat = st.button("Clear Chat", use_container_width=True)
-
-    # Initialize chat history
-    if "chat_history" not in st.session_state:
-        st.session_state.chat_history = []
-
-    [old answer-handling block; most lines truncated in the source]
        st.session_state.chat_history.append({"role": "assistant", "content": answer})
-        # Store supporting chunks in session state for the right sidebar
        st.session_state.selected_chunks = supporting_chunks
-        question = ""
-    except Exception as e:
-        st.error(f"Failed to generate answer: {str(e)}")
-        st.session_state.selected_chunks = []
-
-    # Display chat history
-    st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
-
-    if not st.session_state.chat_history:
-        # Show empty chat state with icon
-        st.markdown("""
-        <div class='empty-chat-placeholder'>
-            <div class='empty-chat-icon'>💬</div>
-            <p>Ask questions about your document to start a conversation</p>
-        </div>
-        """, unsafe_allow_html=True)
-    else:
-        for message in st.session_state.chat_history:
-            if message["role"] == "user":
-                st.markdown(f"""
-                <div class='chat-message user-message'>
-                    <div class='message-content'>
-                        <p>{message["content"]}</p>
-                    </div>
-                </div>
-                """, unsafe_allow_html=True)
-            else:
-                st.markdown(f"""
-                <div class='chat-message assistant-message'>
-                    <div class='message-content'>
-                        <p>{message["content"]}</p>
-                    </div>
-                </div>
-                """, unsafe_allow_html=True)
-    st.markdown("</div>", unsafe_allow_html=True)
-    st.markdown("</div>", unsafe_allow_html=True)

 # --- Supporting Evidence in the right column ---
 with evidence_col:
-    if st.session_state.parsed:
     st.markdown("### Supporting Evidence")

     if not st.session_state.selected_chunks:
         st.info("Evidence chunks will appear here after you ask a question.")
     else:
         for idx, chunk in enumerate(st.session_state.selected_chunks):
-            with st.expander(f"Evidence #{idx+1}", expanded=True):
-                st.markdown(f"**Type:** {chunk['type'].capitalize()}")
                st.markdown(chunk.get('narration', 'No narration available'))
-                # Display table if available
                if 'table_structure' in chunk:
-                    st.write("**Table Data:**")
                    st.dataframe(chunk['table_structure'], use_container_width=True)
-                # Display images if available
                for blk in chunk.get('blocks', []):
-                    if blk.get('type') == 'img_path' and 'images_dir' in st.session_state.…
-                        img_path = os.path.join(st.session_state.…
                        if os.path.exists(img_path):
                            st.image(img_path, use_column_width=True)

app.py (new version):
 from datetime import datetime
 import re
 from werkzeug.utils import secure_filename
+import fitz # PyMuPDF
+import base64

 from src.gpp import GPP, GPPConfig
 from src.qa import AnswerGenerator

 # --- Page Configuration ---
 st.set_page_config(
+    page_title="Document Intelligence",
+    page_icon="🤖",
     layout="wide"
 )

 # --- Session State Initialization ---
 if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = []
+if 'parsed_info' not in st.session_state:
+    st.session_state.parsed_info = None  # Will store {collection_name, layout_pdf, md_path, etc.}
 if "selected_chunks" not in st.session_state:
     st.session_state.selected_chunks = []

+# --- Custom CSS for Messenger-like UI ---
 st.markdown(
     """
     <style>
+    /* Main app background */
+    .stApp { background-color: #121212; /* Dark background */ color: #EAEAEA; /* Light text */ }
+
+    /* Ensure all text in the main content area is light */
+    .st-emotion-cache-16txtl3, .st-emotion-cache-16txtl3 h1, .st-emotion-cache-16txtl3 h2, .st-emotion-cache-16txtl3 h3 { color: #EAEAEA; }
+
+    /* Sidebar adjustments */
+    .st-emotion-cache-16txtl3 { padding-top: 2rem; }
+
+    /* Main chat window container */
+    .chat-window { height: 75vh; background: #1E1E1E; /* Slightly lighter dark for chat window */ border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.4); display: flex; flex-direction: column; overflow: hidden; }
+
+    /* Chat message history */
+    .chat-history { flex-grow: 1; overflow-y: auto; padding: 20px; display: flex; flex-direction: column; gap: 15px; }
+
+    /* General message styling */
+    .message-row { display: flex; align-items: flex-end; gap: 10px; }
+
+    /* Assistant message alignment */
+    .assistant-row { justify-content: flex-start; }
+
+    /* User message alignment */
+    .user-row { justify-content: flex-end; }
+
+    /* Avatar styling */
+    .avatar { width: 40px; height: 40px; border-radius: 50%; display: flex; align-items: center; justify-content: center; font-size: 20px; background-color: #3A3B3C; /* Dark gray for avatar */ color: white; }
+
+    /* Chat bubble styling */
+    .message-bubble { max-width: 70%; padding: 10px 15px; border-radius: 18px; overflow-wrap: break-word; color: #EAEAEA; /* Light text for all bubbles */ }
+    .message-bubble p { margin: 0; }
+
+    /* Assistant bubble color */
+    .assistant-bubble { background-color: #3A3B3C; /* Dark gray for assistant */ }
+
+    /* User bubble color */
+    .user-bubble { background-color: #0084FF; color: white; /* White text for user bubble */ }
+
+    /* Chat input container */
+    .chat-input-container { padding: 15px 20px; background: #1E1E1E; /* Match chat window background */ border-top: 1px solid #3A3B3C; }
+
+    /* Input field styling */
+    .stTextInput>div>div>input { border-radius: 18px; border: 1px solid #555; background-color: #3A3B3C; /* Dark input field */ color: #EAEAEA; /* Light text in input */ padding: 10px 15px; }
+
+    /* Button styling */
+    .stButton>button { border-radius: 18px; border: none; background-color: #0084FF; color: white; height: 42px; }
+
+    /* Hide the default "Get Answer" header for a cleaner look */
+    .st-emotion-cache-16txtl3 > h1 { display: none; }
+
+    /* Empty chat placeholder */
+    .empty-chat-placeholder { flex-grow: 1; display: flex; flex-direction: column; justify-content: center; align-items: center; color: #A0A0A0; /* Lighter gray for placeholder text */ }
+    .empty-chat-placeholder .icon { font-size: 50px; margin-bottom: 10px; }
     </style>
     """, unsafe_allow_html=True
 )
...
     # App info section
     st.image("https://img.icons8.com/ios-filled/50/4A90E2/document.png", width=40)
     st.title("Document Intelligence")
+    st.caption(f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M')}")

     with st.expander("How It Works", expanded=True):
+        st.markdown("1. **Upload & Parse**: Select your PDF to begin.\n2. **Ask Questions**: Use the chat to query your document.\n3. **Get Answers**: The AI provides instant, evidence-backed responses.")

     st.markdown("---")
...
     uploaded_file = st.file_uploader("Select a PDF", type=["pdf"], help="Upload a PDF file to analyze")

     if uploaded_file:
+        filename = secure_filename(uploaded_file.name)
+        # Sanitize filename to be a valid Chroma collection name
+        collection_name = re.sub(r'[^a-zA-Z0-9_-]', '_', os.path.splitext(filename)[0])
+
+        if st.button("Parse Document", use_container_width=True, key="parse_button"):
+            output_dir = os.path.join("./parsed", filename)
+            os.makedirs(output_dir, exist_ok=True)
+            pdf_path = os.path.join(output_dir, filename)
+
+            with open(pdf_path, "wb") as f:
+                f.write(uploaded_file.getbuffer())
+
+            with st.spinner("Processing document..."):
+                try:
+                    gpp = GPP(GPPConfig())
+                    parsed_info = gpp.run(pdf_path, output_dir, collection_name)
+                    st.session_state.parsed_info = parsed_info
+                    st.session_state.chat_history = []
+                    st.session_state.selected_chunks = []
+                    st.success("Document ready!")
+                except Exception as e:
+                    st.error(f"Processing failed: {str(e)}")
+                    st.session_state.parsed_info = None
+
     # Display document preview if parsed
+    if st.session_state.parsed_info:
         st.markdown("---")
         st.subheader("Document Preview")
+        parsed = st.session_state.parsed_info

         # Layout PDF
         layout_pdf = parsed.get("layout_pdf")
         if layout_pdf and os.path.exists(layout_pdf):
             with st.expander("View Layout PDF", expanded=False):
                 st.markdown(f"[Open in new tab]({layout_pdf})")
+                doc = fitz.open(layout_pdf)
+                thumb_width = 500
+                thumbs = []
+                for page_num in range(len(doc)):
+                    page = doc.load_page(page_num)
+                    pix = page.get_pixmap(matrix=fitz.Matrix(thumb_width / page.rect.width, thumb_width / page.rect.width))
+                    img_bytes = pix.tobytes("png")
+                    b64 = base64.b64encode(img_bytes).decode("utf-8")
+                    thumbs.append((page_num, b64))
+                st.markdown("<div style='overflow-x: auto; white-space: nowrap; border: 1px solid #eee; border-radius: 8px; padding: 8px; background: #fafbfc; max-width: 100%;'>", unsafe_allow_html=True)
+                for page_num, b64 in thumbs:
+                    st.markdown(f"<a href='{layout_pdf}#page={page_num+1}' target='_blank' style='display:inline-block;margin-right:8px;'><img src='data:image/png;base64,{b64}' width='{thumb_width}' style='border:1px solid #ccc;border-radius:4px;box-shadow:0 1px 2px #0001;'/></a>", unsafe_allow_html=True)
+                st.markdown("</div>", unsafe_allow_html=True)

         # Content preview
         md_path = parsed.get("md_path")
...
         except Exception as e:
             st.warning(f"Could not preview content: {str(e)}")

+    st.markdown("---")
+    st.subheader("Chat Controls")
+    if st.button("Clear Chat", use_container_width=True):
+        st.session_state.chat_history = []
+        st.session_state.selected_chunks = []
+        st.rerun()

+# --- Main Chat Area ---
+main_col, evidence_col = st.columns([2, 1])

+with main_col:
+    if not st.session_state.parsed_info:
+        st.info("Please upload and parse a document to start the chat.")
     else:
+        # Create a container for the chat window
+        st.markdown("<div class='chat-window'>", unsafe_allow_html=True)

+        # Display chat history
+        st.markdown("<div class='chat-history'>", unsafe_allow_html=True)
+        if not st.session_state.chat_history:
+            st.markdown("""
+            <div class='empty-chat-placeholder'>
+                <span class="icon">🤖</span>
+                <h3>Ask me anything about your document!</h3>
+            </div>
+            """, unsafe_allow_html=True)
+        else:
+            for message in st.session_state.chat_history:
+                if message["role"] == "user":
+                    st.markdown(f"""
+                    <div class="message-row user-row">
+                        <div class="message-bubble user-bubble">
+                            <p>{message["content"]}</p>
+                        </div>
+                    </div>
+                    """, unsafe_allow_html=True)
+                else:
+                    st.markdown(f"""
+                    <div class="message-row assistant-row">
+                        <div class="avatar">🤖</div>
+                        <div class="message-bubble assistant-bubble">
+                            <p>{message["content"]}</p>
+                        </div>
+                    </div>
+                    """, unsafe_allow_html=True)
+        st.markdown("</div>", unsafe_allow_html=True)  # Close chat-history

+        # Chat input bar
+        st.markdown("<div class='chat-input-container'>", unsafe_allow_html=True)
+        input_col, button_col = st.columns([4, 1])
+        with input_col:
+            question = st.text_input("Ask a question...", key="question_input", label_visibility="collapsed")
+        with button_col:
+            send_button = st.button("Send", use_container_width=True)
+
+        st.markdown("</div>", unsafe_allow_html=True)  # Close chat-input-container
+        st.markdown("</div>", unsafe_allow_html=True)  # Close chat-window
+
+        # --- Handle message sending ---
+        if send_button and question:
+            st.session_state.chat_history.append({"role": "user", "content": question})
+
+            with st.spinner("Thinking..."):
+                generator = AnswerGenerator(st.session_state.parsed_info['collection_name'])
+                answer, supporting_chunks = generator.answer(question)
                st.session_state.chat_history.append({"role": "assistant", "content": answer})
                st.session_state.selected_chunks = supporting_chunks
+
+            st.rerun()

 # --- Supporting Evidence in the right column ---
 with evidence_col:
+    if st.session_state.parsed_info:
         st.markdown("### Supporting Evidence")

         if not st.session_state.selected_chunks:
             st.info("Evidence chunks will appear here after you ask a question.")
         else:
             for idx, chunk in enumerate(st.session_state.selected_chunks):
+                with st.expander(f"Evidence Chunk #{idx+1}", expanded=True):
                    st.markdown(chunk.get('narration', 'No narration available'))
                    if 'table_structure' in chunk:
                        st.dataframe(chunk['table_structure'], use_container_width=True)
                    for blk in chunk.get('blocks', []):
+                        if blk.get('type') == 'img_path' and 'images_dir' in st.session_state.parsed_info:
+                            img_path = os.path.join(st.session_state.parsed_info['images_dir'], blk.get('img_path',''))
                            if os.path.exists(img_path):
                                st.image(img_path, use_column_width=True)
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
 # Core
 streamlit>=1.25.0
-sentence-transformers>=2.2.2
-rank-bm25>=0.2.2
-hnswlib>=0.7.0
 huggingface-hub>=0.16.4
 langchain>=0.1.9
 langchain-openai>=0.1.9
@@ -21,7 +22,7 @@ scikit-learn>=1.0.2
 pdfminer.six>=20231228
 torch>=2.6.0
 torchvision
-matplotlib>=3.10
 ultralytics>=8.3.48
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
@@ -30,7 +31,7 @@ PyYAML>=6.0.2,<7
 ftfy>=6.3.1,<7
 openai>=1.70.0,<2
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
 gradio-pdf>=0.0.21
 shapely>=2.0.7,<3
 pyclipper>=1.3.0,<2

requirements.txt (new version):
 # Core
 streamlit>=1.25.0
+sentence-transformers>=2.2.2 # Re-enabled for local embeddings
+# rank-bm25>=0.2.2 - Replaced by ChromaDB
+# hnswlib>=0.7.0 - Replaced by ChromaDB
+chromadb>=0.4.18
 huggingface-hub>=0.16.4
 langchain>=0.1.9
 langchain-openai>=0.1.9
...
 pdfminer.six>=20231228
 torch>=2.6.0
 torchvision
+# matplotlib>=3.10 - Removed, not used in the app
 ultralytics>=8.3.48
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
...
 ftfy>=6.3.1,<7
 openai>=1.70.0,<2
 pydantic>=2.7.2,<2.11
+# transformers>=4.49.0,<5.0.0 - Removed as reranker is disabled
 gradio-pdf>=0.0.21
 shapely>=2.0.7,<3
 pyclipper>=1.3.0,<2
src/__init__.py
CHANGED
@@ -2,6 +2,11 @@ import os
 from dotenv import load_dotenv
 import bleach
 from loguru import logger

 load_dotenv()

@@ -14,35 +19,54 @@ Central configuration for the entire Document Intelligence app.
 All modules import from here rather than hard-coding values.
 """

-… = os.getenv(
-    "OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002"
-)
 class EmbeddingConfig:
-    PROVIDER = os.getenv("EMBEDDING_PROVIDER", '…
     TEXT_MODEL = os.getenv('TEXT_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
-    META_MODEL = os.getenv('META_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')

 class RetrieverConfig:
-    TOP_K = int(os.getenv('RETRIEVER_TOP_K', …
-    DENSE_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
-    ANN_TOP = int(os.getenv('ANN_TOP', 50))
-
-class RerankerConfig:
-    @staticmethod
-    def get_device():
-        import torch
-        return 'cuda' if torch.cuda.is_available() else 'cpu'
-    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')
-    DEVICE = get_device()

 class GPPConfig:
     CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 256))
     DEDUP_SIM_THRESHOLD = float(os.getenv('DEDUP_SIM_THRESHOLD', 0.9))
     EXPANSION_SIM_THRESHOLD = float(os.getenv('EXPANSION_SIM_THRESHOLD', 0.85))
     COREF_CONTEXT_SIZE = int(os.getenv('COREF_CONTEXT_SIZE', 3))
-    [removed HNSW index settings; the lines are truncated in the source]

src/__init__.py (new version):
 from dotenv import load_dotenv
 import bleach
 from loguru import logger
+import streamlit as st
+from sentence_transformers import SentenceTransformer
+import torch
+import chromadb
+from src.utils import OpenAIEmbedder, LocalEmbedder

 load_dotenv()
...
 All modules import from here rather than hard-coding values.
 """

+# --- Embedding & ChromaDB Config ---
 class EmbeddingConfig:
+    PROVIDER = os.getenv("EMBEDDING_PROVIDER", 'local')
     TEXT_MODEL = os.getenv('TEXT_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')

+# --- Retriever Config for Low Latency ---
 class RetrieverConfig:
+    # Retrieve more chunks initially, let the final prompt handle trimming.
+    TOP_K = int(os.getenv('RETRIEVER_TOP_K', 5))

+# --- GPP Config ---
 class GPPConfig:
     CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 256))
     DEDUP_SIM_THRESHOLD = float(os.getenv('DEDUP_SIM_THRESHOLD', 0.9))
     EXPANSION_SIM_THRESHOLD = float(os.getenv('EXPANSION_SIM_THRESHOLD', 0.85))
     COREF_CONTEXT_SIZE = int(os.getenv('COREF_CONTEXT_SIZE', 3))

+# --- Centralized, Streamlit-cached Clients & Models ---
+@st.cache_resource(show_spinner="Connecting to ChromaDB...")
+def get_chroma_client():
+    """
+    Initializes a ChromaDB client.
+    Defaults to a serverless, persistent client, which is ideal for local
+    development and single-container deployments.
+    If CHROMA_HOST is set, it will attempt to connect to a standalone server.
+    """
+    chroma_host = os.getenv("CHROMA_HOST")
+
+    if chroma_host:
+        logger.info(f"Connecting to ChromaDB server at {chroma_host}...")
+        client = chromadb.HttpClient(
+            host=chroma_host,
+            port=int(os.getenv("CHROMA_PORT", "8000"))
+        )
+    else:
+        persist_directory = os.getenv("PERSIST_DIRECTORY", "./parsed/chroma_db")
+        logger.info(f"Using persistent ChromaDB at: {persist_directory}")
+        client = chromadb.PersistentClient(path=persist_directory)
+
+    return client
+
+@st.cache_resource(show_spinner="Loading embedding model...")
+def get_embedder():
+    if EmbeddingConfig.PROVIDER == "openai":
+        logger.info(f"Using OpenAI embedder with model: {EmbeddingConfig.TEXT_MODEL}")
+        return OpenAIEmbedder(model_name=EmbeddingConfig.TEXT_MODEL)
+    else:
+        logger.info(f"Using local embedder with model: {EmbeddingConfig.TEXT_MODEL}")
+        return LocalEmbedder(model_name=EmbeddingConfig.TEXT_MODEL)
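For orientation, a minimal sketch (not part of the commit) of how downstream modules are expected to consume these cached helpers; the collection name "my_doc" is a placeholder, and the query flow mirrors src/retriever.py further down:

    from src import get_chroma_client, get_embedder

    client = get_chroma_client()    # PersistentClient or HttpClient, cached by Streamlit
    embedder = get_embedder()       # LocalEmbedder or OpenAIEmbedder, cached by Streamlit
    collection = client.get_or_create_collection(name="my_doc")
    query_vec = embedder.embed(["What are the key findings?"])[0]
    hits = collection.query(query_embeddings=[query_vec], n_results=5)
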
src/ghm.py
CHANGED
@@ -33,8 +33,8 @@ if __name__ == '__main__':
     mineru_patterns = [
         # "models/Layout/LayoutLMv3/*",
         "models/Layout/YOLO/*",
-        "models/MFD/YOLO/*",
-        "models/MFR/unimernet_hf_small_2503/*",
+        # "models/MFD/YOLO/*",
+        # "models/MFR/unimernet_hf_small_2503/*",
         "models/OCR/paddleocr_torch/*",
         # "models/TabRec/TableMaster/*",
         # "models/TabRec/StructEqTable/*",
src/gpp.py
CHANGED
@@ -17,10 +17,10 @@ import os
 import json
 from typing import List, Dict, Any, Optional
 import re

-from src import EmbeddingConfig, GPPConfig
 from src.utils import OpenAIEmbedder, LLMClient
-from src import logger

 def parse_markdown_table(md: str) -> Optional[Dict[str, Any]]:
     """
@@ -49,21 +49,8 @@ def parse_markdown_table(md: str) -> Optional[Dict[str, Any]]:
 class GPP:
     def __init__(self, config: GPPConfig):
         self.config = config
-
-        # Embedding models
-        if EmbeddingConfig.PROVIDER == "openai":
-            self.text_embedder = OpenAIEmbedder(EmbeddingConfig.TEXT_MODEL)
-            self.meta_embedder = OpenAIEmbedder(EmbeddingConfig.META_MODEL)
-        else:
-            self.text_embedder = SentenceTransformer(
-                EmbeddingConfig.TEXT_MODEL, use_auth_token=True
-            )
-            self.meta_embedder = SentenceTransformer(
-                EmbeddingConfig.META_MODEL, use_auth_token=True
-            )
-
-        self.bm25 = None

     def parse_pdf(self, pdf_path: str, output_dir: str) -> Dict[str, Any]:
         """
@@ -168,27 +155,23 @@ class GPP:

     def deduplicate(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         try:
-            # Lazy import heavy libraries
-            import numpy as np
-            from sentence_transformers import SentenceTransformer
-
             narrations = [c.get("narration", "") for c in chunks]
-            [old embedding and keep-list computation; several lines truncated in the source]
-                if not any(
-                    (emb @ embs[j]).item()
-                    / (np.linalg.norm(emb) * np.linalg.norm(embs[j]) + 1e-8)
-                    > self.config.DEDUP_SIM_THRESHOLD
-                    for j in keep
-                ):
-                    keep.append(i)
-            deduped = [chunks[i] for i in keep]
-            logger.info(f"Deduplicated: {len(chunks)}→{len(deduped)}")
             return deduped
         except Exception as e:
             logger.error(f"Deduplication failed: {e}")
@@ -198,7 +181,7 @@ class GPP:
         for idx, c in enumerate(chunks):
             start = max(0, idx - self.config.COREF_CONTEXT_SIZE)
             ctx = "\n".join(chunks[i].get("narration", "") for i in range(start, idx))
-            prompt = f"Context:\n{ctx}\nRewrite pronouns in:\n{c.get('narration', '')}"
             try:
                 c["narration"] = LLMClient.generate(prompt)
             except Exception as e:
@@ -212,134 +195,68 @@ class GPP:
         for sec, items in sections.items():
             blob = "\n".join(i.get("narration", "") for i in items)
             try:
-                summ = LLMClient.generate(f"Summarize this section:\n{blob}")
                 for i in items:
                     i.setdefault("metadata", {})["section_summary"] = summ
             except Exception as e:
                 logger.error(f"Metadata summarization failed for section {sec}: {e}")

-    def …
-        """
-        Build BM25 index on token lists for sparse retrieval.
-        """
-        # Lazy import heavy libraries
-        from rank_bm25 import BM25Okapi
-
-        tokenized = [c["narration"].split() for c in chunks]
-        self.bm25 = BM25Okapi(tokenized)
-
-    def compute_and_store(self, chunks: List[Dict[str, Any]], output_dir: str) -> None:
        """
-        … and section_summary (meta_vec).
-        2. Build two HNSWlib indices (one for text_vecs, one for meta_vecs).
-        3. Save both indices to disk.
-        4. Dump human-readable chunk metadata (incl. section_summary)
-           for traceability in the UI.
        """
-        from sentence_transformers import SentenceTransformer
-
-        # --- 1. Prepare embedder ---
-        if EmbeddingConfig.PROVIDER.lower() == "openai":
-            embedder = OpenAIEmbedder(EmbeddingConfig.TEXT_MODEL)
-            embed_fn = embedder.embed
-        else:
-            st_model = SentenceTransformer(
-                EmbeddingConfig.TEXT_MODEL, use_auth_token=True
-            )
-            embed_fn = lambda texts: st_model.encode(
-                texts, show_progress_bar=False
-            ).tolist()
-
-        # Batch compute text & meta embeddings ---
-        narrations = [c["narration"] for c in chunks]
-        meta_texts = [c.get("section_summary", "") for c in chunks]
-        logger.info(
-            "computing_embeddings",
-            provider=EmbeddingConfig.PROVIDER,
-            num_chunks=len(chunks),
-        )
-
-        text_vecs = embed_fn(narrations)
-        meta_vecs = embed_fn(meta_texts)
-
-        text_matrix = np.vstack(text_vecs).astype(np.float32)
-        meta_matrix = np.vstack(meta_vecs).astype(np.float32)
-
-        # Build HNSW indices ---
-        dim = text_matrix.shape[1]
-        text_index = hnswlib.Index(space="cosine", dim=dim)
-        text_index.init_index(
-            max_elements=len(chunks),
-            ef_construction=GPPConfig.HNSW_EF_CONSTRUCTION,
-            M=GPPConfig.HNSW_M,
-        )
-        ids = [c["id"] for c in chunks]
-        text_index.add_items(text_matrix, ids)
-        text_index.set_ef(GPPConfig.HNSW_EF_SEARCH)
-        logger.info("text_hnsw_built", elements=len(chunks))
-
-        # Meta index (same dim)
-        meta_index = hnswlib.Index(space="cosine", dim=dim)
-        meta_index.init_index(
-            max_elements=len(chunks),
-            ef_construction=GPPConfig.HNSW_EF_CONSTRUCTION,
-            M=GPPConfig.HNSW_M,
-        )
-        meta_index.add_items(meta_matrix, ids)
-        meta_index.set_ef(GPPConfig.HNSW_EF_SEARCH)
-        logger.info("meta_hnsw_built", elements=len(chunks))
-
-        # Persist indices to disk ---
-        text_idx_path = os.path.join(output_dir, "hnsw_text_index.bin")
-        meta_idx_path = os.path.join(output_dir, "hnsw_meta_index.bin")
-        text_index.save_index(text_idx_path)
-        meta_index.save_index(meta_idx_path)
-        logger.info(
-            "hnsw_indices_saved", text_index=text_idx_path, meta_index=meta_idx_path
-        )
-
-        # Dump chunk metadata for UI traceability ---
-        meta_path = os.path.join(output_dir, "chunk_metadata.json")
-        metadata = {
-            str(c["id"]): {
-                "text": c.get("text", ""),
-                "narration": c["narration"],
-                "type": c.get("type", ""),
-                "section_summary": c.get("section_summary", ""),
-            }
-            for c in chunks
-        }
-        with open(meta_path, "w", encoding="utf-8") as f:
-            json.dump(metadata, f, ensure_ascii=False, indent=2)
-        logger.info("chunk_metadata_saved", path=meta_path)
-
-    def run(self, pdf_path: str, output_dir: str) -> Dict[str, Any]:
        """
-        Executes …
        """
-        …
-        blocks = …
        chunks = self.chunk_blocks(blocks)
-        # assigning IDs to chunks for traceability
        for idx, chunk in enumerate(chunks):
            chunk["id"] = idx
        self.narrate_multimodal(chunks)
-        [remaining removed pipeline steps; the lines are truncated in the source]

src/gpp.py (new version):
 import json
 from typing import List, Dict, Any, Optional
 import re
+import numpy as np

+from src import EmbeddingConfig, GPPConfig, logger, get_embedder, get_chroma_client
 from src.utils import OpenAIEmbedder, LLMClient

 def parse_markdown_table(md: str) -> Optional[Dict[str, Any]]:
     """
...
 class GPP:
     def __init__(self, config: GPPConfig):
         self.config = config
+        self.text_embedder = get_embedder()
+        self.chroma_client = get_chroma_client()

     def parse_pdf(self, pdf_path: str, output_dir: str) -> Dict[str, Any]:
         """
...
     def deduplicate(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         try:
             narrations = [c.get("narration", "") for c in chunks]
+            embs = self.text_embedder.embed(narrations)
+
+            # Simple cosine similarity check
+            keep_indices = []
+            for i in range(len(embs)):
+                is_duplicate = False
+                for j_idx in keep_indices:
+                    sim = np.dot(embs[i], embs[j_idx]) / (np.linalg.norm(embs[i]) * np.linalg.norm(embs[j_idx]))
+                    if sim > self.config.DEDUP_SIM_THRESHOLD:
+                        is_duplicate = True
+                        break
+                if not is_duplicate:
+                    keep_indices.append(i)

+            deduped = [chunks[i] for i in keep_indices]
+            logger.info(f"Deduplicated: {len(chunks)} -> {len(deduped)}")
             return deduped
         except Exception as e:
             logger.error(f"Deduplication failed: {e}")
...
         for idx, c in enumerate(chunks):
             start = max(0, idx - self.config.COREF_CONTEXT_SIZE)
             ctx = "\n".join(chunks[i].get("narration", "") for i in range(start, idx))
+            prompt = f"Context:\n{ctx}\nRewrite pronouns in:\n{c.get('narration', '')}\n\n give only the rewritten text, no other text"
             try:
                 c["narration"] = LLMClient.generate(prompt)
             except Exception as e:
...
         for sec, items in sections.items():
             blob = "\n".join(i.get("narration", "") for i in items)
             try:
+                summ = LLMClient.generate(f"Summarize this section:\n{blob}\n\n give only the summarized text, no other text")
                 for i in items:
                     i.setdefault("metadata", {})["section_summary"] = summ
             except Exception as e:
                 logger.error(f"Metadata summarization failed for section {sec}: {e}")

+    def store_in_chroma(self, chunks: List[Dict[str, Any]], collection_name: str) -> None:
         """
+        Computes embeddings and stores the chunks in a ChromaDB collection.
         """
+        if not chunks:
+            logger.warning("No chunks to store in ChromaDB.")
+            return

+        collection = self.chroma_client.get_or_create_collection(name=collection_name)
+
+        # Prepare data for ChromaDB
+        documents = [c['narration'] for c in chunks]
+        metadatas = []
+        for chunk in chunks:
+            # metadata can only contain str, int, float, bool
+            meta = {k: v for k, v in chunk.items() if k not in ['narration', 'text', 'id'] and type(v) in [str, int, float, bool]}
+            meta['text'] = chunk.get('text', '')  # Add original text to metadata
+            metadatas.append(meta)
+
+        ids = [str(c['id']) for c in chunks]
+
+        logger.info(f"Storing {len(chunks)} chunks in ChromaDB collection '{collection_name}'...")
+        try:
+            collection.add(
+                ids=ids,
+                documents=documents,
+                metadatas=metadatas
             )
+            logger.info("Successfully stored chunks in ChromaDB.")
+        except Exception as e:
+            logger.error(f"Failed to store chunks in ChromaDB: {e}")
+            raise

+    def run(self, pdf_path: str, output_dir: str, collection_name: str) -> Dict[str, Any]:
         """
+        Executes a streamlined GPP: parse -> chunk -> narrate -> store.
+        Heavy enhancement steps are bypassed for maximum efficiency.
         """
+        parsed_output = self.parse_pdf(pdf_path, output_dir)
+        blocks = parsed_output.get("blocks", [])
+
         chunks = self.chunk_blocks(blocks)
         for idx, chunk in enumerate(chunks):
             chunk["id"] = idx
+
         self.narrate_multimodal(chunks)
+
+        # NOTE: Heavy enhancement steps are disabled for performance.
+        # To re-enable, uncomment the following lines:
+        # chunks = self.deduplicate(chunks)
+        # self.coref_resolution(chunks)
+        # self.metadata_summarization(chunks)
+
+        self.store_in_chroma(chunks, collection_name)
+
+        parsed_output["chunks"] = chunks
+        parsed_output["collection_name"] = collection_name
+        logger.info("GPP pipeline complete. Data stored in ChromaDB.")
+        return parsed_output
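As a quick illustration (not from the repo), the kind of chunk dict store_in_chroma above expects: 'narration' becomes the Chroma document, 'id' the document id, and only primitive-typed fields survive the metadata filter; any field names beyond those used in the code are hypothetical:

    chunk = {
        "id": 7,
        "type": "table",                                   # kept in metadata (str)
        "narration": "Quarterly revenue grew 12% in Q3.",  # stored as the Chroma document
        "text": "| Quarter | Revenue |",                   # re-added to metadata explicitly
        "blocks": [{"type": "img_path", "img_path": "fig1.png"}],  # dropped (not a primitive)
    }
    # resulting Chroma metadata: {"type": "table", "text": "| Quarter | Revenue |"}
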
src/qa.py
CHANGED
@@ -9,94 +9,95 @@ This module contains:
 Each component is modular and can be swapped or extended (e.g., add HyDE retriever).
 """
 import os
 from typing import List, Dict, Any, Tuple
-import streamlit as st

-from src import …
 from src.utils import LLMClient
-from src.retriever import Retriever

-class Reranker:
     """
-    …
     """
-    [removed reranker helper lines; truncated in the source]
-    def __init__(self, config: RerankerConfig):
-        try:
-            self.tokenizer, self.model = self.load_model_and_tokenizer(config.MODEL_NAME, config.DEVICE)
-        except Exception as e:
-            logger.error(f'Failed to load reranker model: {e}')
-            raise
-
-    def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
-        """Score each candidate and return top_k sorted by relevance."""
-        if not candidates:
-            logger.warning('No candidates provided to rerank.')
-            return []
-        try:
-            import torch
-            inputs = self.tokenizer(
-                [query] * len(candidates),
-                [c.get('narration', '') for c in candidates],
-                padding=True,
-                truncation=True,
-                return_tensors='pt'
-            ).to(RerankerConfig.DEVICE)
-            with torch.no_grad():
-                out = self.model(**inputs)
-
-            logits = out.logits
-            if logits.ndim == 2 and logits.shape[1] == 1:
-                logits = logits.squeeze(-1)  # only squeeze if it's (batch, 1)
-
-            return [
-                …
-            ]

-class AnswerGenerator:
-    """
-    Main interface: initializes Retriever + Reranker once, then
-    answers multiple questions without re-loading models each time.
-    """
-    def __init__(self, chunks: List[Dict[str, Any]]):
-        self.chunks = chunks
-        self.retriever = Retriever(chunks, RetrieverConfig)
-        self.reranker = Reranker(RerankerConfig)
-        self.top_k = RetrieverConfig.TOP_K // 2
-
-    def answer(
-        self, question: str
-    ) -> Tuple[str, List[Dict[str, Any]]]:
         candidates = self.retriever.retrieve(question)
-        [removed reranking step; the lines are truncated in the source]
         context = "\n\n".join(f"- {c['narration']}" for c in top_chunks)
         prompt = (
-            "You are a …
-            [removed prompt text; the lines are truncated in the source]
             "Answer:"
         )
-        [removed LLM call; truncated in the source]
         return answer, top_chunks

src/qa.py (new version):
 Each component is modular and can be swapped or extended (e.g., add HyDE retriever).
 """
 import os
+import random
 from typing import List, Dict, Any, Tuple

+from src import logger, RetrieverConfig
 from src.utils import LLMClient
+from src.retriever import Retriever

+class AnswerGenerator:
     """
+    Generates answers by retrieving documents from a vector store
+    and using them to build a context for an LLM.
+    This version is optimized for low latency by skipping the reranking step.
     """
+    def __init__(self, collection_name: str):
+        self.retriever = Retriever(collection_name, RetrieverConfig)
+        self.context_chunks_count = 5  # Use top 5 chunks for the final prompt
+        self.greetings = [
+            "Hello! I'm ready to answer your questions about the document. What would you like to know?",
+            "Hi there! How can I help you with your document today?",
+            "Hey! I've got the document open and I'm ready for your questions.",
+            "Greetings! Ask me anything about the document, and I'll do my best to find the answer for you."
+        ]

+    def _truncate_to_last_sentence(self, text: str) -> str:
+        """Finds the last period or newline and truncates the text to that point."""
+        # Find the last period
+        last_period = text.rfind('.')
+        # Find the last newline
+        last_newline = text.rfind('\n')
+        # Find the last of the two
+        last_marker = max(last_period, last_newline)

+        if last_marker != -1:
+            return text[:last_marker + 1].strip()
+
+        # If no sentence-ending punctuation, return the text as is (or a portion)
+        return text

+    def answer(self, question: str) -> Tuple[str, List[Dict[str, Any]]]:
+        """
+        Retrieves documents, builds a context, and generates an answer.
+        Handles simple greetings separately to improve user experience.
+        """
+        # Handle simple greetings to avoid a failed retrieval
+        normalized_question = question.lower().strip().rstrip('.,!')
+        greeting_triggers = ["hi", "hello", "hey", "hallo", "hola"]
+        if normalized_question in greeting_triggers:
+            return random.choice(self.greetings), []

+        # Retrieve candidate documents from the vector store
         candidates = self.retriever.retrieve(question)
+
+        if not candidates:
+            logger.warning("No candidates retrieved from vector store.")
+            return "The document does not contain information on this topic.", []
+
+        # Use the top N chunks for context, without reranking
+        top_chunks = candidates[:self.context_chunks_count]
+
         context = "\n\n".join(f"- {c['narration']}" for c in top_chunks)
+
+        # A more robust prompt that encourages a natural, conversational tone
         prompt = (
+            "You are a helpful and friendly AI assistant for document analysis. "
+            "Your user is asking a question about a document. "
+            "Based *only* on the context provided below, formulate a clear and conversational answer. "
+            "Adopt a helpful and slightly informal tone, as if you were a knowledgeable colleague.\n\n"
+            "CONTEXT:\n"
+            "---------------------\n"
+            f"{context}\n"
+            "---------------------\n\n"
+            "USER'S QUESTION: "
+            f'"{question}"\n\n'
+            "YOUR TASK:\n"
+            "1. Carefully read the provided context.\n"
+            "2. If the context contains the answer, explain it to the user in a natural, conversational way. Do not just repeat the text verbatim.\n"
+            "3. If the context does not contain the necessary information, respond with: "
+            "'I've checked the document, but I couldn't find any information on that topic.'\n"
+            "4. **Crucially, do not use any information outside of the provided context.**\n\n"
             "Answer:"
         )
+
+        answer, finish_reason = LLMClient.generate(prompt, max_tokens=256)
+
+        # Handle cases where the response might be cut off
+        if finish_reason == 'length':
+            logger.warning("LLM response was truncated due to token limit.")
+            truncated_answer = self._truncate_to_last_sentence(answer)
+            answer = truncated_answer + " ... (response shortened)"
+
         return answer, top_chunks
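A brief usage sketch (mirrors how app.py above wires it up; the collection name is whatever GPP.run stored, here assumed to be "my_doc"):

    from src.qa import AnswerGenerator

    generator = AnswerGenerator("my_doc")
    answer, supporting_chunks = generator.answer("What are the key findings?")
    print(answer)
    for chunk in supporting_chunks:  # each chunk carries 'id', 'narration' and any primitive metadata
        print(chunk["id"], chunk["narration"][:80])
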
src/retriever.py
CHANGED
@@ -1,110 +1,62 @@
 import os
 from typing import List, Dict, Any
-import …

-from src import RetrieverConfig, logger

 class Retriever:
     """
-    …
     """
-    [removed hybrid BM25 + HNSW retriever; some lines are truncated in the source]

-    def __init__(self, chunks, config):
-        # Lazy import heavy libraries
-        import numpy as np
-        import hnswlib
-        from rank_bm25 import BM25Okapi
-        self.chunks = chunks
-        try:
-            if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
-                logger.error("Chunks must be a list of dicts.")
-                raise ValueError("Chunks must be a list of dicts.")
-            corpus = [c.get('narration', '').split() for c in chunks]
-            self.bm25 = BM25Okapi(corpus)
-            self.embedder = self.load_embedder(config.DENSE_MODEL)
-            dim = len(self.embedder.encode(["test"])[0])
-            self.ann = hnswlib.Index(space='cosine', dim=dim)
-            self.ann.init_index(max_elements=len(chunks))
-            embeddings = self.embedder.encode([c.get('narration', '') for c in chunks])
-            self.ann.add_items(embeddings, ids=list(range(len(chunks))))
-            self.ann.set_ef(config.ANN_TOP)
-        except Exception as e:
-            logger.error(f"Retriever init failed: {e}")
-            self.bm25 = None
-            self.embedder = None
-            self.ann = None

-    def retrieve_sparse(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
-        …
-        Args:
-            query (str): Query string.
-            top_k (int): Number of top chunks to return.
-
-        Returns:
-            List[Dict[str, Any]]: List of top chunks.
        """
-        if …
-            import numpy as np  # Ensure np is defined here
-            scores = self.bm25.get_scores(tokenized)
-            top_indices = np.argsort(scores)[::-1][:top_k]
-            return [self.chunks[i] for i in top_indices]
-        except Exception as e:
-            logger.error(f"Sparse retrieval failed: {e}")
            return []

-    def retrieve_dense(self, query: str, top_k: int) -> List[Dict[str, Any]]:
-        """
-        Retrieve chunks using dense retrieval.
-
-        Args:
-            query (str): Query string.
-            top_k (int): Number of top chunks to return.
-
-        Returns:
-            List[Dict[str, Any]]: List of top chunks.
-        """
-        if not self.ann or not self.embedder:
-            logger.error("Dense retriever not initialized.")
-            return []
         try:
-            [removed dense query and result-mapping code; truncated in the source]

-        if top_k is None:
-            top_k = RetrieverConfig.TOP_K
-        sparse = self.retrieve_sparse(query, top_k)
-        dense = self.retrieve_dense(query, top_k)
-        seen = set()
-        combined = []
-        for c in sparse + dense:
-            cid = id(c)
-            if cid not in seen:
-                seen.add(cid)
-                combined.append(c)
-        return combined

src/retriever.py (new version):
 import os
 from typing import List, Dict, Any
+import numpy as np

+from src import RetrieverConfig, logger, get_chroma_client, get_embedder

 class Retriever:
     """
+    Retrieves documents from a ChromaDB collection.
     """
+    def __init__(self, collection_name: str, config: RetrieverConfig):
+        self.collection_name = collection_name
+        self.config = config
+        self.client = get_chroma_client()
+        self.embedder = get_embedder()
+        self.collection = self.client.get_or_create_collection(name=self.collection_name)

+    def retrieve(self, query: str, top_k: int = None) -> List[Dict[str, Any]]:
         """
+        Embeds a query and retrieves the top_k most similar documents from ChromaDB.
         """
+        if top_k is None:
+            top_k = self.config.TOP_K
+
+        if self.collection.count() == 0:
+            logger.warning(f"Chroma collection '{self.collection_name}' is empty. Cannot retrieve.")
             return []

         try:
+            # 1. Embed the query
+            query_embedding = self.embedder.embed([query])[0]
+
+            # 2. Query ChromaDB
+            results = self.collection.query(
+                query_embeddings=[query_embedding],
+                n_results=top_k,
+                include=["metadatas", "documents"]
+            )
+
+            # 3. Format results into chunks
+            # Chroma returns lists of lists, so we access the first element.
+            if not results or not results.get('ids', [[]])[0]:
+                return []
+
+            ids = results['ids'][0]
+            documents = results['documents'][0]
+            metadatas = results['metadatas'][0]
+
+            retrieved_chunks = []
+            for i, doc_id in enumerate(ids):
+                chunk = {
+                    'id': doc_id,
+                    'narration': documents[i],
+                    **metadatas[i]  # Add all other metadata from Chroma
+                }
+                retrieved_chunks.append(chunk)
+
+            return retrieved_chunks
+
+        except Exception as e:
+            logger.error(f"ChromaDB retrieval failed for collection '{self.collection_name}': {e}")
+            return []
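A standalone usage sketch of the new Retriever (illustrative only; assumes a collection named "my_doc" that GPP.run has already populated):

    from src import RetrieverConfig
    from src.retriever import Retriever

    retriever = Retriever("my_doc", RetrieverConfig)
    for chunk in retriever.retrieve("summarize the revenue table", top_k=3):
        print(chunk["id"], chunk["narration"][:80])
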
src/utils.py
CHANGED
@@ -6,6 +6,7 @@ import openai
 from typing import List
 from openai import AzureOpenAI
 from langchain_openai import AzureOpenAIEmbeddings
 from src import logger

@@ -15,7 +16,7 @@ class LLMClient:
     Reads API key from environment and exposes `generate(prompt)`.
     """
     @staticmethod
-    def generate(prompt: str, model: str = None, max_tokens: int = 512, **kwargs) -> str:
         azure_api_key = os.getenv('AZURE_API_KEY')
         azure_endpoint = os.getenv('AZURE_ENDPOINT')
         azure_api_version = os.getenv('AZURE_API_VERSION')
@@ -39,24 +40,57 @@ class LLMClient:
                 **kwargs
             )
             text = resp.choices[0].message.content.strip()
-            return text
         except Exception as e:
             logger.error(f'LLM generation failed: {e}')
             raise

 class OpenAIEmbedder:
     """
-    Wrapper around OpenAI Embeddings
-    …
-    embs = embedder.embed([str1, str2, ...])
     """
     def __init__(self, model_name: str):
-        self.…
-        [removed initializer lines; truncated in the source]

     def embed(self, texts: List[str]) -> List[List[float]]:
-        [removed embed body; the lines are truncated in the source]

src/utils.py (new version):
 from typing import List
 from openai import AzureOpenAI
 from langchain_openai import AzureOpenAIEmbeddings
+from sentence_transformers import SentenceTransformer
 from src import logger
...
     Reads API key from environment and exposes `generate(prompt)`.
     """
     @staticmethod
+    def generate(prompt: str, model: str = None, max_tokens: int = 512, **kwargs) -> tuple[str, str]:
         azure_api_key = os.getenv('AZURE_API_KEY')
         azure_endpoint = os.getenv('AZURE_ENDPOINT')
         azure_api_version = os.getenv('AZURE_API_VERSION')
...
                 **kwargs
             )
             text = resp.choices[0].message.content.strip()
+            finish_reason = resp.choices[0].finish_reason
+            return text, finish_reason
         except Exception as e:
             logger.error(f'LLM generation failed: {e}')
             raise

+class LocalEmbedder:
+    """
+    Wrapper for a local SentenceTransformer model.
+    """
+    def __init__(self, model_name: str):
+        self.model = SentenceTransformer(model_name)
+        logger.info(f"Initialized local embedder with model: {model_name}")
+
+    def embed(self, texts: List[str]) -> List[List[float]]:
+        """Embeds a list of texts using the local SentenceTransformer model."""
+        try:
+            embeddings = self.model.encode(texts, show_progress_bar=False)
+            return embeddings.tolist()
+        except Exception as e:
+            logger.error(f"Local embedding failed: {e}")
+            raise
+
+
 class OpenAIEmbedder:
     """
+    Wrapper around OpenAI and Azure OpenAI Embeddings.
+    Automatically uses Azure credentials if available, otherwise falls back to OpenAI.
     """
     def __init__(self, model_name: str):
+        self.model_name = model_name
+        self.is_azure = os.getenv('AZURE_API_KEY') and os.getenv('AZURE_ENDPOINT')
+
+        if self.is_azure:
+            logger.info("Using Azure OpenAI for embeddings.")
+            self.embedder = AzureOpenAIEmbeddings(
+                model=self.model_name,
+                azure_deployment=os.getenv("AZURE_EMBEDDING_DEPLOYMENT"),  # Assumes a deployment name is set
+                api_version=os.getenv("AZURE_API_VERSION")
+            )
+        else:
+            logger.info("Using standard OpenAI for embeddings.")
+            # This part would need OPENAI_API_KEY to be set
+            from langchain_openai import OpenAIEmbeddings
+            self.embedder = OpenAIEmbeddings(model=self.model_name)

     def embed(self, texts: List[str]) -> List[List[float]]:
+        """Embeds a list of texts."""
+        try:
+            return self.embedder.embed_documents(texts)
+        except Exception as e:
+            logger.error(f"Embedding failed: {e}")
+            raise