import os
import tempfile

# Fix cache permissions for Hugging Face Spaces - MUST be at the very top
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
os.environ['TORCH_HOME'] = '/tmp/torch_cache'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp/sentence_transformers'

# Create cache directories
cache_dirs = ['/tmp/transformers_cache', '/tmp/huggingface_cache', '/tmp/torch_cache', '/tmp/sentence_transformers']
for cache_dir in cache_dirs:
    os.makedirs(cache_dir, exist_ok=True)
from flask import Flask, render_template, request, jsonify
import json
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import re
from transformers import pipeline
import torch
import nltk
from collections import Counter
import unicodedata
# Download required NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception as e:
    print(f"Warning: Could not download NLTK data: {e}")
app = Flask(__name__)
class EvaluationMetrics:
    """Class to calculate F1, BLEU, and ROUGE-L scores"""

    @staticmethod
    def normalize_text(text):
        """Normalize text for comparison"""
        # Normalize unicode and collapse extra whitespace
        text = unicodedata.normalize('NFKD', text)
        text = ' '.join(text.split())
        return text.lower()

    @staticmethod
    def tokenize_text(text):
        """Simple tokenization for Hindi/English mixed text"""
        # Replace punctuation with spaces; \w also matches Devanagari word characters
        text = re.sub(r'[^\w\s]', ' ', text)
        return text.split()

    @staticmethod
    def calculate_f1_score(predicted, reference):
        """Calculate token-level F1 score between predicted and reference text"""
        try:
            pred_tokens = set(EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(predicted)))
            ref_tokens = set(EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(reference)))
            if len(pred_tokens) == 0 and len(ref_tokens) == 0:
                return 1.0
            if len(pred_tokens) == 0 or len(ref_tokens) == 0:
                return 0.0
            # Calculate intersection
            common_tokens = pred_tokens.intersection(ref_tokens)
            if len(common_tokens) == 0:
                return 0.0
            precision = len(common_tokens) / len(pred_tokens)
            recall = len(common_tokens) / len(ref_tokens)
            f1 = 2 * (precision * recall) / (precision + recall)
            return float(f1)  # Convert to Python float
        except Exception:
            return 0.0
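
    # Illustrative check (not from the original source): for predicted tokens
    # {"ganga", "aarti", "sham"} and reference tokens {"ganga", "aarti", "sham",
    # "pratidin"}, precision = 3/3 = 1.0 and recall = 3/4 = 0.75,
    # so F1 = 2 * 0.75 / 1.75 ≈ 0.8571.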

    @staticmethod
    def calculate_bleu_score(predicted, reference):
        """Calculate BLEU score"""
        try:
            from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
            pred_tokens = EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(predicted))
            ref_tokens = EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(reference))
            if len(pred_tokens) == 0 or len(ref_tokens) == 0:
                return 0.0
            # Use a smoothing function to handle zero n-gram overlaps
            smoothing = SmoothingFunction()
            bleu_score = sentence_bleu(
                [ref_tokens],
                pred_tokens,
                smoothing_function=smoothing.method1
            )
            return float(bleu_score)  # Convert to Python float
        except Exception:
            # Fallback BLEU calculation if NLTK fails
            return EvaluationMetrics.simple_bleu(predicted, reference)

    @staticmethod
    def simple_bleu(predicted, reference):
        """Simple BLEU fallback: clipped unigram precision"""
        try:
            pred_tokens = EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(predicted))
            ref_tokens = EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(reference))
            if len(pred_tokens) == 0 or len(ref_tokens) == 0:
                return 0.0
            # Clipped 1-gram precision (simplified BLEU)
            pred_counts = Counter(pred_tokens)
            ref_counts = Counter(ref_tokens)
            overlap = sum(min(pred_counts[token], ref_counts[token]) for token in pred_counts)
            precision = overlap / len(pred_tokens)
            return float(precision)  # Convert to Python float
        except Exception:
            return 0.0
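
    # Illustrative check (not from the original source): predicted tokens
    # ["a", "a", "b"] vs reference ["a", "b", "c"] give a clipped overlap of
    # min(2, 1) + min(1, 1) = 2, so unigram precision = 2/3.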

    @staticmethod
    def lcs_length(x, y):
        """Calculate Longest Common Subsequence length via dynamic programming"""
        m, n = len(x), len(y)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if x[i-1] == y[j-1]:
                    dp[i][j] = dp[i-1][j-1] + 1
                else:
                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])
        return dp[m][n]
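
    # Illustrative check (not from the original source): for
    # x = ["a", "b", "c", "d"] and y = ["a", "c", "d"], the longest common
    # subsequence is ["a", "c", "d"], so lcs_length returns 3.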

    @staticmethod
    def calculate_rouge_l(predicted, reference):
        """Calculate ROUGE-L (F-measure over the longest common subsequence)"""
        try:
            pred_tokens = EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(predicted))
            ref_tokens = EvaluationMetrics.tokenize_text(EvaluationMetrics.normalize_text(reference))
            if len(pred_tokens) == 0 and len(ref_tokens) == 0:
                return 1.0
            if len(pred_tokens) == 0 or len(ref_tokens) == 0:
                return 0.0
            lcs_len = EvaluationMetrics.lcs_length(pred_tokens, ref_tokens)
            if lcs_len == 0:
                return 0.0
            precision = lcs_len / len(pred_tokens)
            recall = lcs_len / len(ref_tokens)
            if precision + recall == 0:
                return 0.0
            rouge_l = 2 * precision * recall / (precision + recall)
            return float(rouge_l)  # Convert to Python float
        except Exception:
            return 0.0

    @staticmethod
    def calculate_all_scores(predicted, reference):
        """Calculate all evaluation metrics"""
        if not predicted or not reference:
            return {
                'f1_score': 0.0,
                'bleu_score': 0.0,
                'rouge_l_score': 0.0
            }
        return {
            'f1_score': round(EvaluationMetrics.calculate_f1_score(predicted, reference), 4),
            'bleu_score': round(EvaluationMetrics.calculate_bleu_score(predicted, reference), 4),
            'rouge_l_score': round(EvaluationMetrics.calculate_rouge_l(predicted, reference), 4)
        }
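
    # Minimal usage sketch: for identical predicted/reference strings, f1_score
    # and rouge_l_score are 1.0 by construction; bleu_score also depends on
    # sentence length and smoothing, so it can be below 1.0 for very short texts.
    #   EvaluationMetrics.calculate_all_scores("predicted text", "reference text")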


class ImprovedVATIKAChatbot:
    def __init__(self):
        print("🚀 Initializing KashiVani Chatbot...")
        # Load multilingual embedding model
        self.embedding_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        print("✅ Embedding model loaded")
        # Load QA model with better error handling
        self.qa_pipeline = self.load_qa_model()
        # Initialize data structures
        self.contexts = []
        self.context_embeddings = None
        self.qa_pairs = []  # Store all Q&A pairs separately
        self.qa_embeddings = None
        # Initialize evaluation metrics
        self.evaluation_metrics = EvaluationMetrics()
        # Load and process data
        self.load_data()
        print(f"✅ Loaded {len(self.contexts)} contexts and {len(self.qa_pairs)} Q&A pairs")

    def load_qa_model(self):
        """Load QA model with fallback options"""
        models_to_try = [
            "deepset/xlm-roberta-base-squad2",
            # Note: this base checkpoint is not fine-tuned for QA; if reached, it
            # loads with a randomly initialized QA head and will answer poorly
            "distilbert-base-multilingual-cased",
            "deepset/minilm-uncased-squad2"
        ]
        for model_name in models_to_try:
            try:
                print(f"🔄 Trying to load QA model: {model_name}")
                qa_pipeline = pipeline(
                    "question-answering",
                    model=model_name,
                    tokenizer=model_name,
                    device=-1  # Use CPU
                )
                print(f"✅ Successfully loaded: {model_name}")
                return qa_pipeline
            except Exception as e:
                print(f"❌ Failed to load {model_name}: {e}")
                continue
        print("⚠️ Could not load any QA model, using fallback")
        return None

    def load_data(self):
        """Load and preprocess VATIKA dataset with better error handling"""
        try:
            # Check if data files exist
            train_file = 'data/train.json'
            val_file = 'data/validation.json'
            if not os.path.exists(train_file):
                print("❌ Train file not found, creating sample data...")
                self.create_sample_data()
            # Load training data
            with open(train_file, 'r', encoding='utf-8') as f:
                train_data = json.load(f)
            all_data = train_data.get('domains', [])
            # Try to load validation data
            if os.path.exists(val_file):
                try:
                    with open(val_file, 'r', encoding='utf-8') as f:
                        val_data = json.load(f)
                    all_data.extend(val_data.get('domains', []))
                    print("✅ Validation data loaded")
                except Exception as e:
                    print(f"⚠️ Could not load validation data: {e}")
            # Process data
            self.process_data(all_data)
        except Exception as e:
            print(f"❌ Error loading data: {e}")
            self.create_fallback_data()

    def process_data(self, domains_data):
        """Process loaded data and create embeddings"""
        all_contexts = []
        all_qas = []
        for domain_data in domains_data:
            domain = domain_data.get('domain', 'unknown')
            for context_data in domain_data.get('contexts', []):
                context_text = context_data.get('context', '')
                qas = context_data.get('qas', [])
                if context_text.strip():  # Only add non-empty contexts
                    context_info = {
                        'domain': domain,
                        'context': context_text,
                        'qas': qas
                    }
                    all_contexts.append(context_info)
                    # Extract Q&A pairs
                    for qa in qas:
                        question = qa.get('question', '').strip()
                        answer = qa.get('answer', '').strip()
                        if question and answer:
                            qa_info = {
                                'question': question,
                                'answer': answer,
                                'domain': domain,
                                'context': context_text
                            }
                            all_qas.append(qa_info)
        self.contexts = all_contexts
        self.qa_pairs = all_qas
        # Create embeddings
        if self.contexts:
            print("🔄 Creating context embeddings...")
            context_texts = [ctx['context'] for ctx in self.contexts]
            self.context_embeddings = self.embedding_model.encode(context_texts, show_progress_bar=True)
        if self.qa_pairs:
            print("🔄 Creating Q&A embeddings...")
            qa_questions = [qa['question'] for qa in self.qa_pairs]
            self.qa_embeddings = self.embedding_model.encode(qa_questions, show_progress_bar=True)
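
    # Note (illustrative, not from the original source): encode() returns a
    # numpy array with one 384-dimensional row per input text for this MiniLM
    # model; the lookup methods below pass a single-query embedding and these
    # matrices to cosine_similarity to rank stored questions and contexts.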

    def create_sample_data(self):
        """Create sample data if original data is not available"""
        sample_data = {
            "domains": [
                {
                    "domain": "varanasi_temples",
                    "contexts": [
                        {
                            "context": "काशी विश्वनाथ मंदिर वाराणसी का सबसे प्रसिद्ध और पवित्र मंदिर है। यह भगवान शिव को समर्पित है और गंगा नदी के पश्चिमी तट पर स्थित है। यह 12 ज्योतिर्लिंगों में से एक है और हिंदू धर्म में अत्यंत महत्वपूर्ण माना जाता है।",
                            "qas": [
                                {
                                    "question": "काशी विश्वनाथ मंदिर कहाँ स्थित है?",
                                    "answer": "काशी विश्वनाथ मंदिर वाराणसी में गंगा नदी के पश्चिमी तट पर स्थित है।"
                                },
                                {
                                    "question": "काशी विश्वनाथ मंदिर किसे समर्पित है?",
                                    "answer": "काशी विश्वनाथ मंदिर भगवान शिव को समर्पित है।"
                                },
                                {
                                    "question": "क्या काशी विश्वनाथ ज्योतिर्लिंग है?",
                                    "answer": "हाँ, काशी विश्वनाथ मंदिर 12 ज्योतिर्लिंगों में से एक है।"
                                }
                            ]
                        }
                    ]
                },
                {
                    "domain": "varanasi_ghats",
                    "contexts": [
                        {
                            "context": "दशाश्वमेध घाट वाराणसी का सबसे प्रसिद्ध और मुख्य घाट है। यहाँ प्रतिदिन शाम को भव्य गंगा आरती का आयोजन होता है। यह घाट अत्यंत पवित्र माना जाता है और हजारों श्रद्धालु और पर्यटक यहाँ आते हैं।",
                            "qas": [
                                {
                                    "question": "दशाश्वमेध घाट पर आरती कब होती है?",
                                    "answer": "दशाश्वमेध घाट पर प्रतिदिन शाम को गंगा आरती होती है।"
                                },
                                {
                                    "question": "वाराणसी का सबसे प्रसिद्ध घाट कौन सा है?",
                                    "answer": "दशाश्वमेध घाट वाराणसी का सबसे प्रसिद्ध घाट है।"
                                }
                            ]
                        }
                    ]
                }
            ]
        }
        os.makedirs('data', exist_ok=True)
        with open('data/train.json', 'w', encoding='utf-8') as f:
            json.dump(sample_data, f, ensure_ascii=False, indent=2)
        print("✅ Sample data created")

    def create_fallback_data(self):
        """Create minimal fallback data"""
        self.contexts = [{
            'domain': 'general',
            'context': 'वाराणसी भारत का एक प्राचीन और पवित्र शहर है।',
            'qas': [{'question': 'वाराणसी क्या है?', 'answer': 'वाराणसी भारत का एक प्राचीन और पवित्र शहर है।'}]
        }]
        self.qa_pairs = [{
            'question': 'वाराणसी क्या है?',
            'answer': 'वाराणसी भारत का एक प्राचीन और पवित्र शहर है।',
            'domain': 'general'
        }]
        context_texts = [ctx['context'] for ctx in self.contexts]
        self.context_embeddings = self.embedding_model.encode(context_texts)
        qa_questions = [qa['question'] for qa in self.qa_pairs]
        self.qa_embeddings = self.embedding_model.encode(qa_questions)

    def find_best_qa_match(self, query, threshold=0.6):
        """Find the best matching stored Q&A pair by embedding similarity"""
        if not self.qa_pairs or self.qa_embeddings is None:
            return None
        query_embedding = self.embedding_model.encode([query])
        similarities = cosine_similarity(query_embedding, self.qa_embeddings)[0]
        best_idx = np.argmax(similarities)
        best_score = similarities[best_idx]
        if best_score > threshold:
            return {
                'qa': self.qa_pairs[best_idx],
                'score': float(best_score)  # Convert to Python float
            }
        return None
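
    # Note (illustrative): with the default threshold of 0.6, only fairly close
    # paraphrases of a stored question produce a direct answer; weaker matches
    # fall through to context retrieval in process_query below.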

    def find_relevant_context(self, query, top_k=3, threshold=0.3):
        """Find the most relevant contexts for a query"""
        if not self.contexts or self.context_embeddings is None:
            return []
        query_embedding = self.embedding_model.encode([query])
        similarities = cosine_similarity(query_embedding, self.context_embeddings)[0]
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        relevant_contexts = []
        for idx in top_indices:
            if similarities[idx] > threshold:
                relevant_contexts.append({
                    'context': self.contexts[idx],
                    'similarity': float(similarities[idx])  # Convert to Python float
                })
        return relevant_contexts
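
    # Illustrative check (not from the original source): for similarities
    # [0.2, 0.9, 0.5, 0.7] and top_k=3, np.argsort(...)[-3:][::-1] yields
    # indices [1, 3, 2] (scores 0.9, 0.7, 0.5), each of which also passes
    # the 0.3 threshold; the 0.2 context never enters the top-k slice.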

    def generate_qa_answer(self, question, context):
        """Generate an answer with the extractive QA model"""
        if not self.qa_pipeline:
            return None
        try:
            # Truncate context if too long
            max_context_length = 500
            if len(context) > max_context_length:
                context = context[:max_context_length] + "..."
            result = self.qa_pipeline(question=question, context=context)
            if result['score'] > 0.15:  # Confidence threshold
                return result['answer']
        except Exception as e:
            print(f"QA Pipeline error: {e}")
        return None

    def get_smart_fallback(self, query):
        """Generate keyword-based fallback responses"""
        query_lower = query.lower()
        # Keyword-based responses
        responses = {
            ('मंदिर', 'temple'): "वाराणसी में काशी विश्वनाथ मंदिर, संकट मोचन हनुमान मंदिर, दुर्गा मंदिर जैसे प्रसिद्ध मंदिर हैं। किसी विशिष्ट मंदिर के बारे में पूछें।",
            ('घाट', 'ghat'): "वाराणसी में दशाश्वमेध घाट, मणिकर्णिका घाट, अस्सी घाट जैसे प्रसिद्ध घाट हैं। किसी विशिष्ट घाट के बारे में जानना चाहते हैं?",
            ('आरती', 'aarti'): "गंगा आरती दशाश्वमेध घाट पर प्रतिदिन शाम को होती है। यह बहुत ही मनोहर और भव्य होती है।",
            ('गंगा', 'ganga'): "गंगा नदी वाराणसी की जीवनधारा है। यहाँ लोग स्नान करते हैं और आरती देखते हैं।",
            ('यात्रा', 'travel', 'घूमना'): "वाराणसी में आप मंदिर, घाट, गलियाँ, और सांस्कृतिक स्थल देख सकते हैं। क्या विशिष्ट जानकारी चाहिए?"
        }
        for keywords, response in responses.items():
            if any(keyword in query_lower for keyword in keywords):
                return response
        return "मुझे वाराणसी के बारे में आपका प्रश्न समझ नहीं आया। कृपया मंदिर, घाट, आरती, या यात्रा के बारे में पूछें।"

    def process_query(self, query):
        """Main query processing function with evaluation metrics"""
        if not query.strip():
            return {
                'answer': "कृपया अपना प्रश्न पूछें।",
                'reference_answer': None,
                'evaluation_scores': None,
                'response_type': 'empty_query'
            }
        print(f"🔍 Processing query: {query}")

        # Step 1: Try to find a direct Q&A match
        qa_match = self.find_best_qa_match(query)
        if qa_match:
            print(f"✅ Found Q&A match with score: {qa_match['score']:.3f}")
            # For direct Q&A matches, the stored answer doubles as the reference,
            # so the evaluation scores are perfect by construction
            predicted_answer = qa_match['qa']['answer']
            reference_answer = qa_match['qa']['answer']
            evaluation_scores = self.evaluation_metrics.calculate_all_scores(
                predicted_answer, reference_answer
            )
            return {
                'answer': predicted_answer,
                'reference_answer': reference_answer,
                'evaluation_scores': evaluation_scores,
                'response_type': 'direct_qa_match',
                'similarity_score': qa_match['score']
            }

        # Step 2: Find relevant contexts
        relevant_contexts = self.find_relevant_context(query)
        reference_answer = None
        if relevant_contexts:
            print(f"✅ Found {len(relevant_contexts)} relevant contexts")
            # Step 3: Try the QA model on the best context
            best_context = relevant_contexts[0]['context']
            qa_answer = self.generate_qa_answer(query, best_context['context'])
            if qa_answer:
                # Try to find a reference answer from the context's QAs
                reference_answer = self.find_reference_answer(query, best_context['qas'])
                evaluation_scores = self.evaluation_metrics.calculate_all_scores(
                    qa_answer, reference_answer
                ) if reference_answer else None
                return {
                    'answer': qa_answer,
                    'reference_answer': reference_answer,
                    'evaluation_scores': evaluation_scores,
                    'response_type': 'qa_model_generated',
                    'context_similarity': relevant_contexts[0]['similarity']
                }
            # Step 4: Check for direct Q&As in the context
            for qa in best_context['qas']:
                if self.is_similar_question(query, qa['question']):
                    predicted_answer = qa['answer']
                    reference_answer = qa['answer']
                    evaluation_scores = self.evaluation_metrics.calculate_all_scores(
                        predicted_answer, reference_answer
                    )
                    return {
                        'answer': predicted_answer,
                        'reference_answer': reference_answer,
                        'evaluation_scores': evaluation_scores,
                        'response_type': 'context_qa_match',
                        'context_similarity': relevant_contexts[0]['similarity']
                    }

        # Step 5: Smart fallback
        fallback_answer = self.get_smart_fallback(query)
        return {
            'answer': fallback_answer,
            'reference_answer': None,
            'evaluation_scores': None,
            'response_type': 'fallback'
        }

    def find_reference_answer(self, query, qas):
        """Find the most similar question's answer to use as a reference"""
        if not qas:
            return None
        best_similarity = 0
        best_answer = None
        for qa in qas:
            if self.is_similar_question(query, qa['question'], threshold=0.5):
                # Calculate the similarity score
                try:
                    embeddings = self.embedding_model.encode([query, qa['question']])
                    similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_answer = qa['answer']
                except Exception:
                    continue
        return best_answer

    def is_similar_question(self, q1, q2, threshold=0.7):
        """Check whether two questions are semantically similar"""
        try:
            embeddings = self.embedding_model.encode([q1, q2])
            similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
            return similarity > threshold
        except Exception:
            return False


# Initialize the chatbot once at import time so all routes share it
chatbot = ImprovedVATIKAChatbot()


def convert_to_serializable(obj):
    """Recursively convert numpy types to Python native types for JSON serialization"""
    if isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    else:
        return obj
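
# Illustrative check (not from the original source):
#   convert_to_serializable({'score': np.float32(0.91), 'ids': np.array([1, 2])})
#   -> {'score': 0.91..., 'ids': [1, 2]}  (plain Python types, safe for jsonify)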


@app.route('/')  # route path assumed; the decorator was missing from the source
def home():
    return render_template('index.html')


@app.route('/chat', methods=['POST'])  # path taken from the "/chat route" comment below
def chat():
    try:
        data = request.get_json()
        user_message = data.get('message', '').strip()
        if not user_message:
            return jsonify({'error': 'कृपया कोई संदेश भेजें'}), 400
        # Process the query
        result = chatbot.process_query(user_message)
        # Convert all values to a JSON-serializable format
        result = convert_to_serializable(result)
        # Prepare response
        response_data = {
            'response': result['answer'],
            'status': 'success',
            'response_type': result['response_type'],
            'evaluation_scores': result['evaluation_scores'],
            'reference_answer': result['reference_answer'],
        }
        # Add additional info based on response type
        if result['response_type'] == 'direct_qa_match':
            response_data['similarity_score'] = result['similarity_score']
        elif result['response_type'] in ['qa_model_generated', 'context_qa_match']:
            response_data['context_similarity'] = result['context_similarity']
        # Add debug info in development
        if app.debug:
            response_data['debug'] = {
                'total_contexts': len(chatbot.contexts),
                'total_qas': len(chatbot.qa_pairs),
                'model_loaded': chatbot.qa_pipeline is not None
            }
        # Final conversion to ensure everything is serializable
        response_data = convert_to_serializable(response_data)
        return jsonify(response_data)
    except Exception as e:
        print(f"❌ Chat error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({
            'error': f'कुछ गलती हुई है: {str(e)}',
            'status': 'error'
        }), 500


@app.route('/health')  # route path assumed; the decorator was missing from the source
def health():
    return jsonify({
        'status': 'healthy',
        'contexts_loaded': len(chatbot.contexts),
        'qas_loaded': len(chatbot.qa_pairs),
        'embeddings_ready': chatbot.context_embeddings is not None,
        'qa_model_loaded': chatbot.qa_pipeline is not None
    })


@app.route('/debug')  # route path assumed; the decorator was missing from the source
def debug():
    """Debug endpoint to check data"""
    return jsonify({
        'contexts': len(chatbot.contexts),
        'qa_pairs': len(chatbot.qa_pairs),
        'sample_context': chatbot.contexts[0] if chatbot.contexts else None,
        'sample_qa': chatbot.qa_pairs[0] if chatbot.qa_pairs else None
    })


# Add this route AFTER your existing /chat route and BEFORE if __name__ == "__main__":
@app.route('/api/predict', methods=['POST'])  # route path assumed; the decorator was missing from the source
def api_predict():
    """
    API endpoint for external applications
    Expected JSON: {"question": "your question", "context": "optional context"}
    Returns: {"answer": "response", "success": true, "evaluation_scores": {...}}
    """
    try:
        # Handle both JSON and form data
        if request.is_json:
            data = request.get_json()
        else:
            data = request.form.to_dict()
        # Accept the question under any of the supported keys; strip once at the end
        question = (data.get('question') or data.get('message') or data.get('query') or '').strip()
        provided_context = data.get('context', '').strip()
        if not question:
            return jsonify({
                'success': False,
                'error': 'कृपया प्रश्न भेजें',
                'answer': ''
            }), 400
        # If a context is provided, use it directly for QA
        if provided_context:
            if chatbot.qa_pipeline:
                try:
                    # Truncate context if too long
                    max_context_length = 500
                    if len(provided_context) > max_context_length:
                        provided_context = provided_context[:max_context_length] + "..."
                    result = chatbot.qa_pipeline(question=question, context=provided_context)
                    if result['score'] > 0.15:  # Confidence threshold
                        response_data = {
                            'success': True,
                            'answer': result['answer'],
                            'confidence': float(result['score']),
                            'response_type': 'context_based_qa',
                            'evaluation_scores': None,
                            'reference_answer': None
                        }
                        return jsonify(convert_to_serializable(response_data))
                except Exception as e:
                    print(f"Context QA error: {e}")
        # Otherwise fall back to the existing chatbot logic
        result = chatbot.process_query(question)
        # Convert to a serializable format
        result = convert_to_serializable(result)
        # Prepare the API response
        response_data = {
            'success': True,
            'answer': result['answer'],
            'response_type': result['response_type'],
            'evaluation_scores': result['evaluation_scores'],
            'reference_answer': result['reference_answer']
        }
        # Add additional info based on response type
        if result['response_type'] == 'direct_qa_match':
            response_data['similarity_score'] = result['similarity_score']
        elif result['response_type'] in ['qa_model_generated', 'context_qa_match']:
            response_data['context_similarity'] = result['context_similarity']
        return jsonify(response_data)
    except Exception as e:
        print(f"❌ API error: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': f'कुछ गलती हुई है: {str(e)}',
            'answer': ''
        }), 500


@app.route('/api/simple', methods=['GET', 'POST'])  # route path assumed; methods per the docstring
def api_simple():
    """
    Simple API endpoint for basic question-answering
    POST: JSON {"question": "your question"}
    GET: ?question=your+question
    Returns: {"answer": "response"}
    """
    try:
        if request.method == 'POST':
            if request.is_json:
                data = request.get_json()
                question = data.get('question', '').strip()
            else:
                question = request.form.get('question', '').strip()
        else:  # GET
            question = request.args.get('question', '').strip()
        if not question:
            return jsonify({
                'answer': 'कृपया प्रश्न भेजें'
            }), 400
        # Process the query
        result = chatbot.process_query(question)
        # Simple response
        return jsonify({
            'answer': result['answer']
        })
    except Exception as e:
        print(f"❌ Simple API error: {e}")
        return jsonify({
            'answer': f'कुछ गलती हुई है: {str(e)}'
        }), 500


if __name__ == "__main__":
    # HF Spaces requirement: port 7860
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)
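
# Minimal usage sketch (assumes the reconstructed route paths above):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "काशी विश्वनाथ मंदिर कहाँ स्थित है?"}'
#   curl "http://localhost:7860/api/simple?question=वाराणसी+क्या+है"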