Ayush083 commited on
Commit
e96fa19
·
verified ·
1 Parent(s): fb5cd92

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +405 -0
  2. dockerfile +24 -0
  3. main.py +16 -0
  4. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ import json
3
+ import os
4
+ from sentence_transformers import SentenceTransformer
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ import re
8
+ from transformers import pipeline
9
+ import torch
10
+
11
+ app = Flask(__name__)
12
+
13
+
14
+ class ImprovedVATIKAChatbot:
15
+ def __init__(self):
16
+ print("🚀 Initializing VATIKA Chatbot...")
17
+
18
+ # Load multilingual embedding model
19
+ self.embedding_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
20
+ print("✅ Embedding model loaded")
21
+
22
+ # Load QA model with better error handling
23
+ self.qa_pipeline = self.load_qa_model()
24
+
25
+ # Initialize data structures
26
+ self.contexts = []
27
+ self.context_embeddings = None
28
+ self.qa_pairs = [] # Store all Q&A pairs separately
29
+ self.qa_embeddings = None
30
+
31
+ # Load and process data
32
+ self.load_data()
33
+ print(f"✅ Loaded {len(self.contexts)} contexts and {len(self.qa_pairs)} Q&A pairs")
34
+
35
+ def load_qa_model(self):
36
+ """Load QA model with fallback options"""
37
+ models_to_try = [
38
+ "deepset/xlm-roberta-base-squad2",
39
+ "distilbert-base-multilingual-cased",
40
+ "deepset/minilm-uncased-squad2"
41
+ ]
42
+
43
+ for model_name in models_to_try:
44
+ try:
45
+ print(f"🔄 Trying to load QA model: {model_name}")
46
+ qa_pipeline = pipeline(
47
+ "question-answering",
48
+ model=model_name,
49
+ tokenizer=model_name,
50
+ device=-1 # Use CPU
51
+ )
52
+ print(f"✅ Successfully loaded: {model_name}")
53
+ return qa_pipeline
54
+ except Exception as e:
55
+ print(f"❌ Failed to load {model_name}: {e}")
56
+ continue
57
+
58
+ print("⚠️ Could not load any QA model, using fallback")
59
+ return None
60
+
61
+ def load_data(self):
62
+ """Load and preprocess VATIKA dataset with better error handling"""
63
+ try:
64
+ # Check if data files exist
65
+ train_file = 'data/train.json'
66
+ val_file = 'data/validation.json'
67
+
68
+ if not os.path.exists(train_file):
69
+ print("❌ Train file not found, creating sample data...")
70
+ self.create_sample_data()
71
+
72
+ # Load training data
73
+ with open(train_file, 'r', encoding='utf-8') as f:
74
+ train_data = json.load(f)
75
+
76
+ all_data = train_data.get('domains', [])
77
+
78
+ # Try to load validation data
79
+ if os.path.exists(val_file):
80
+ try:
81
+ with open(val_file, 'r', encoding='utf-8') as f:
82
+ val_data = json.load(f)
83
+ all_data.extend(val_data.get('domains', []))
84
+ print("✅ Validation data loaded")
85
+ except Exception as e:
86
+ print(f"⚠️ Could not load validation data: {e}")
87
+
88
+ # Process data
89
+ self.process_data(all_data)
90
+
91
+ except Exception as e:
92
+ print(f"❌ Error loading data: {e}")
93
+ self.create_fallback_data()
94
+
95
+ def process_data(self, domains_data):
96
+ """Process loaded data and create embeddings"""
97
+ all_contexts = []
98
+ all_qas = []
99
+
100
+ for domain_data in domains_data:
101
+ domain = domain_data.get('domain', 'unknown')
102
+
103
+ for context_data in domain_data.get('contexts', []):
104
+ context_text = context_data.get('context', '')
105
+ qas = context_data.get('qas', [])
106
+
107
+ if context_text.strip(): # Only add non-empty contexts
108
+ context_info = {
109
+ 'domain': domain,
110
+ 'context': context_text,
111
+ 'qas': qas
112
+ }
113
+ all_contexts.append(context_info)
114
+
115
+ # Extract Q&A pairs
116
+ for qa in qas:
117
+ question = qa.get('question', '').strip()
118
+ answer = qa.get('answer', '').strip()
119
+
120
+ if question and answer:
121
+ qa_info = {
122
+ 'question': question,
123
+ 'answer': answer,
124
+ 'domain': domain,
125
+ 'context': context_text
126
+ }
127
+ all_qas.append(qa_info)
128
+
129
+ self.contexts = all_contexts
130
+ self.qa_pairs = all_qas
131
+
132
+ # Create embeddings
133
+ if self.contexts:
134
+ print("🔄 Creating context embeddings...")
135
+ context_texts = [ctx['context'] for ctx in self.contexts]
136
+ self.context_embeddings = self.embedding_model.encode(context_texts, show_progress_bar=True)
137
+
138
+ if self.qa_pairs:
139
+ print("🔄 Creating Q&A embeddings...")
140
+ qa_questions = [qa['question'] for qa in self.qa_pairs]
141
+ self.qa_embeddings = self.embedding_model.encode(qa_questions, show_progress_bar=True)
142
+
143
+ def create_sample_data(self):
144
+ """Create sample data if original data is not available"""
145
+ sample_data = {
146
+ "domains": [
147
+ {
148
+ "domain": "varanasi_temples",
149
+ "contexts": [
150
+ {
151
+ "context": "काशी विश्वनाथ मंदिर वाराणसी का सबसे प्रसिद्ध और पवित्र मंदिर है। यह भगवान शिव को समर्पित है और गंगा नदी के पश्चिमी तट पर स्थित है। यह 12 ज्योतिर्लिंगों में से एक है और हिंदू धर्म में अत्यंत महत्वपूर्ण माना जाता है।",
152
+ "qas": [
153
+ {
154
+ "question": "काशी विश्वनाथ मंदिर कहाँ स्थित है?",
155
+ "answer": "काशी विश्वनाथ मंदिर वाराणसी में गंगा नदी के पश्चिमी तट पर स्थित है।"
156
+ },
157
+ {
158
+ "question": "काशी विश्वनाथ मंदिर किसे समर्पित है?",
159
+ "answer": "काशी विश्वनाथ मंदिर भगवान शिव को समर्पित है।"
160
+ },
161
+ {
162
+ "question": "क्या काशी विश्वनाथ ज्योतिर्लिंग है?",
163
+ "answer": "हाँ, काशी विश्वनाथ मंदिर 12 ज्योतिर्लिंगों में से एक है।"
164
+ }
165
+ ]
166
+ }
167
+ ]
168
+ },
169
+ {
170
+ "domain": "varanasi_ghats",
171
+ "contexts": [
172
+ {
173
+ "context": "दशाश्वमेध घाट वाराणसी का सबसे प्रसिद्ध और मुख्य घाट है। यहाँ प्रतिदिन शाम को भव्य गंगा आरती का आयोजन होता है। यह घाट अत्यंत पवित्र माना जाता है और हजारों श्रद्धालु और पर्यटक यहाँ आते हैं।",
174
+ "qas": [
175
+ {
176
+ "question": "दशाश्वमेध घाट पर आरती कब होती है?",
177
+ "answer": "दशाश्वमेध घाट पर प्रतिदिन शाम को गंगा आरती होती है।"
178
+ },
179
+ {
180
+ "question": "वाराणसी का सबसे प्रसिद्ध घाट कौन सा है?",
181
+ "answer": "दशाश्वमेध घाट वाराणसी का सबसे प्रसिद्ध घाट है।"
182
+ }
183
+ ]
184
+ }
185
+ ]
186
+ }
187
+ ]
188
+ }
189
+
190
+ os.makedirs('data', exist_ok=True)
191
+ with open('data/train.json', 'w', encoding='utf-8') as f:
192
+ json.dump(sample_data, f, ensure_ascii=False, indent=2)
193
+
194
+ print("✅ Sample data created")
195
+
196
+ def create_fallback_data(self):
197
+ """Create minimal fallback data"""
198
+ self.contexts = [{
199
+ 'domain': 'general',
200
+ 'context': 'वाराणसी भारत का एक प्राचीन और पवित्र शहर है।',
201
+ 'qas': [{'question': 'वाराणसी क्या है?', 'answer': 'वाराणसी भारत का एक प्राचीन और पवित्र शहर है।'}]
202
+ }]
203
+
204
+ self.qa_pairs = [{'question': 'वाराणसी क्या है?', 'answer': 'वाराणसी भारत का एक प्राचीन और पवित्र शहर है।',
205
+ 'domain': 'general'}]
206
+
207
+ context_texts = [ctx['context'] for ctx in self.contexts]
208
+ self.context_embeddings = self.embedding_model.encode(context_texts)
209
+
210
+ qa_questions = [qa['question'] for qa in self.qa_pairs]
211
+ self.qa_embeddings = self.embedding_model.encode(qa_questions)
212
+
213
+ def find_best_qa_match(self, query, threshold=0.6):
214
+ """Find best matching Q&A pair"""
215
+ if not self.qa_pairs or self.qa_embeddings is None:
216
+ return None
217
+
218
+ query_embedding = self.embedding_model.encode([query])
219
+ similarities = cosine_similarity(query_embedding, self.qa_embeddings)[0]
220
+
221
+ best_idx = np.argmax(similarities)
222
+ best_score = similarities[best_idx]
223
+
224
+ if best_score > threshold:
225
+ return {
226
+ 'qa': self.qa_pairs[best_idx],
227
+ 'score': best_score
228
+ }
229
+
230
+ return None
231
+
232
+ def find_relevant_context(self, query, top_k=3, threshold=0.3):
233
+ """Find most relevant contexts"""
234
+ if not self.contexts or self.context_embeddings is None:
235
+ return []
236
+
237
+ query_embedding = self.embedding_model.encode([query])
238
+ similarities = cosine_similarity(query_embedding, self.context_embeddings)[0]
239
+
240
+ top_indices = np.argsort(similarities)[-top_k:][::-1]
241
+
242
+ relevant_contexts = []
243
+ for idx in top_indices:
244
+ if similarities[idx] > threshold:
245
+ relevant_contexts.append({
246
+ 'context': self.contexts[idx],
247
+ 'similarity': similarities[idx]
248
+ })
249
+
250
+ return relevant_contexts
251
+
252
+ def generate_qa_answer(self, question, context):
253
+ """Generate answer using QA model"""
254
+ if not self.qa_pipeline:
255
+ return None
256
+
257
+ try:
258
+ # Truncate context if too long
259
+ max_context_length = 500
260
+ if len(context) > max_context_length:
261
+ context = context[:max_context_length] + "..."
262
+
263
+ result = self.qa_pipeline(question=question, context=context)
264
+
265
+ if result['score'] > 0.15: # Confidence threshold
266
+ return result['answer']
267
+
268
+ except Exception as e:
269
+ print(f"QA Pipeline error: {e}")
270
+
271
+ return None
272
+
273
+ def get_smart_fallback(self, query):
274
+ """Generate smart fallback responses"""
275
+ query_lower = query.lower()
276
+
277
+ # Keywords-based responses
278
+ responses = {
279
+ ('मंदिर',
280
+ 'temple'): "वाराणसी में काशी विश्वनाथ मंदिर, संकट मोचन हनुमान मंदिर, दुर्गा मंदिर जैसे प्रसिद्ध मंदिर हैं। किसी विशिष्ट मंदिर के बारे में पूछें।",
281
+ ('घाट',
282
+ 'ghat'): "वाराणसी में दशाश्वमेध घाट, मणिकर्णिका घाट, अस्सी घाट जैसे प्रसिद्ध घाट हैं। किसी विशिष्ट घाट के बारे में जानना चाहते हैं?",
283
+ ('आरती', 'aarti'): "गंगा आरती दशाश्वमेध घाट पर प्रतिदिन शाम को होती है। यह बहुत ही मनोहर और भव्य होती है।",
284
+ ('गंगा', 'ganga'): "गंगा नदी वाराणसी की जीवनधारा है। यहाँ लोग स्नान करते हैं और आरती देखते हैं।",
285
+ ('यात्रा', 'travel',
286
+ 'घूमना'): "वाराणसी में आप मंदिर, घाट, गलियाँ, और सांस्कृतिक स्थल देख सकते हैं। क्या विशिष्ट जानकारी चाहिए?"
287
+ }
288
+
289
+ for keywords, response in responses.items():
290
+ if any(keyword in query_lower for keyword in keywords):
291
+ return response
292
+
293
+ return "मुझे वाराणसी के बारे में आपका प्रश्न समझ नहीं आया। कृपया मंदिर, घाट, आरती, या यात्रा के बारे में पूछें।"
294
+
295
+ def process_query(self, query):
296
+ """Main query processing function"""
297
+ if not query.strip():
298
+ return "कृपया अपना प्रश्न पूछें।"
299
+
300
+ print(f"🔍 Processing query: {query}")
301
+
302
+ # Step 1: Try to find direct Q&A match
303
+ qa_match = self.find_best_qa_match(query)
304
+ if qa_match:
305
+ print(f"✅ Found Q&A match with score: {qa_match['score']:.3f}")
306
+ return qa_match['qa']['answer']
307
+
308
+ # Step 2: Find relevant contexts
309
+ relevant_contexts = self.find_relevant_context(query)
310
+
311
+ if relevant_contexts:
312
+ print(f"✅ Found {len(relevant_contexts)} relevant contexts")
313
+
314
+ # Step 3: Try QA model on best context
315
+ best_context = relevant_contexts[0]['context']
316
+ qa_answer = self.generate_qa_answer(query, best_context['context'])
317
+
318
+ if qa_answer:
319
+ return qa_answer
320
+
321
+ # Step 4: Check for direct Q&As in the context
322
+ for qa in best_context['qas']:
323
+ if self.is_similar_question(query, qa['question']):
324
+ return qa['answer']
325
+
326
+ # Step 5: Smart fallback
327
+ return self.get_smart_fallback(query)
328
+
329
+ def is_similar_question(self, q1, q2, threshold=0.7):
330
+ """Check if two questions are similar"""
331
+ try:
332
+ embeddings = self.embedding_model.encode([q1, q2])
333
+ similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
334
+ return similarity > threshold
335
+ except:
336
+ return False
337
+
338
+
339
+ # Initialize improved chatbot
340
+ chatbot = ImprovedVATIKAChatbot()
341
+
342
+
343
+ @app.route('/')
344
+ def home():
345
+ return render_template('index.html')
346
+
347
+
348
+ @app.route('/chat', methods=['POST'])
349
+ def chat():
350
+ try:
351
+ data = request.get_json()
352
+ user_message = data.get('message', '').strip()
353
+
354
+ if not user_message:
355
+ return jsonify({'error': 'कृपया कोई संदेश भेजें'}), 400
356
+
357
+ # Process the query
358
+ response = chatbot.process_query(user_message)
359
+
360
+ # Add debug info in development
361
+ debug_info = {
362
+ 'total_contexts': len(chatbot.contexts),
363
+ 'total_qas': len(chatbot.qa_pairs),
364
+ 'model_loaded': chatbot.qa_pipeline is not None
365
+ }
366
+
367
+ return jsonify({
368
+ 'response': response,
369
+ 'status': 'success',
370
+ 'debug': debug_info if app.debug else None
371
+ })
372
+
373
+ except Exception as e:
374
+ print(f"❌ Chat error: {e}")
375
+ return jsonify({
376
+ 'error': f'कुछ गलती हुई है: {str(e)}',
377
+ 'status': 'error'
378
+ }), 500
379
+
380
+
381
+ @app.route('/health')
382
+ def health():
383
+ return jsonify({
384
+ 'status': 'healthy',
385
+ 'contexts_loaded': len(chatbot.contexts),
386
+ 'qas_loaded': len(chatbot.qa_pairs),
387
+ 'embeddings_ready': chatbot.context_embeddings is not None,
388
+ 'qa_model_loaded': chatbot.qa_pipeline is not None
389
+ })
390
+
391
+
392
+ @app.route('/debug')
393
+ def debug():
394
+ """Debug endpoint to check data"""
395
+ return jsonify({
396
+ 'contexts': len(chatbot.contexts),
397
+ 'qa_pairs': len(chatbot.qa_pairs),
398
+ 'sample_context': chatbot.contexts[0] if chatbot.contexts else None,
399
+ 'sample_qa': chatbot.qa_pairs[0] if chatbot.qa_pairs else None
400
+ })
401
+
402
+ if __name__ == "__main__":
403
+ # HF Spaces requirement: port 7860
404
+ port = int(os.environ.get("PORT", 7860))
405
+ app.run(host="0.0.0.0", port=port, debug=False)
dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # System dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ gcc \
8
+ g++ \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements
12
+ COPY requirements.txt .
13
+
14
+ # Install Python packages
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy app files
18
+ COPY . .
19
+
20
+ # Expose port 7860 (HF Spaces requirement)
21
+ EXPOSE 7860
22
+
23
+ # Run Flask app
24
+ CMD ["python", "app.py"]
main.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a sample Python script.
2
+
3
+ # Press Shift+F10 to execute it or replace it with your code.
4
+ # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
5
+
6
+
7
+ def print_hi(name):
8
+ # Use a breakpoint in the code line below to debug your script.
9
+ print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
10
+
11
+
12
+ # Press the green button in the gutter to run the script.
13
+ if __name__ == '__main__':
14
+ print_hi('PyCharm')
15
+
16
+ # See PyCharm help at https://www.jetbrains.com/help/pycharm/
requirements.txt ADDED
Binary file (374 Bytes). View file