# NOTE: "Spaces: / Running / Running" below was a Hugging Face Spaces status
# banner captured during copy/paste, not part of the source. Kept as a comment
# so the module remains valid Python.
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
from sentence_transformers import SentenceTransformer | |
from typing import List, Dict, Any, Tuple | |
import re | |
import json | |
from collections import defaultdict | |
import logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
class VATIKARetriever:
    """
    Advanced retrieval system for the VATIKA dataset.

    Indexes a list of context records — each a dict with a ``context`` text,
    a ``domain`` label and a ``qas`` list of question/answer dicts — and
    supports semantic, keyword-based, domain-aware and hybrid retrieval
    strategies for better context matching.
    """

    # Minimum cosine similarity for a context to count as relevant
    # (shared by semantic and domain-aware retrieval).
    MIN_SIMILARITY = 0.2

    def __init__(self, embedding_model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        self.embedding_model = SentenceTransformer(embedding_model_name)
        self.contexts = []                     # raw context records
        self.context_embeddings = None         # (n_contexts, dim) ndarray once indexed
        self.domain_embeddings = {}            # domain name -> 1-D embedding
        self.keyword_index = defaultdict(set)  # keyword/domain -> set of context indices
        self.question_embeddings = []          # (n_qa_pairs, dim) ndarray once indexed
        self.qa_pairs = []                     # flattened Q&A records

    def load_and_index_data(self, contexts_data: List[Dict[str, Any]]):
        """
        Load data and create multiple indexes for efficient retrieval
        """
        logger.info("Loading and indexing VATIKA data...")
        self.contexts = contexts_data
        self._create_context_embeddings()
        self._create_domain_embeddings()
        self._create_keyword_index()
        self._create_qa_embeddings()
        logger.info(f"Indexed {len(self.contexts)} contexts successfully")

    def _create_context_embeddings(self):
        """Create embeddings for all contexts"""
        context_texts = [ctx['context'] for ctx in self.contexts]
        self.context_embeddings = self.embedding_model.encode(
            context_texts,
            show_progress_bar=True,
            batch_size=32
        )
        logger.info(f"Created embeddings for {len(context_texts)} contexts")

    def _create_domain_embeddings(self):
        """Create domain-specific embeddings"""
        domain_contexts = defaultdict(list)
        for ctx in self.contexts:
            domain_contexts[ctx['domain']].append(ctx['context'])
        for domain, contexts in domain_contexts.items():
            # Combine all contexts for a domain into one representative text.
            combined_context = " ".join(contexts)
            domain_embedding = self.embedding_model.encode([combined_context])
            self.domain_embeddings[domain] = domain_embedding[0]
        logger.info(f"Created domain embeddings for {len(self.domain_embeddings)} domains")

    def _create_keyword_index(self):
        """Create keyword-based index for fast lookups"""
        # Canonical Hindi keywords mapped to transliteration / synonym variants.
        hindi_keywords = {
            'घाट': ['ghat', 'घाट', 'तट'],
            'मंदिर': ['temple', 'मंदिर', 'देवालय'],
            'आरती': ['aarti', 'आरती', 'पूजा'],
            'भोजन': ['food', 'भोजन', 'खाना'],
            'होटल': ['hotel', 'होटल', 'आवास'],
            'यात्रा': ['travel', 'यात्रा', 'सफर'],
            'समय': ['time', 'समय', 'टाइम'],
            'दूरी': ['distance', 'दूरी', 'फासला'],
            'कैसे': ['how', 'कैसे'],  # duplicate 'कैसे' variant removed
            'कहां': ['where', 'कहां', 'कहाँ'],
            'क्या': ['what', 'क्या'],
            'कब': ['when', 'कब'],
            'कितना': ['how much', 'कितना', 'कितनी'],
        }
        for idx, ctx in enumerate(self.contexts):
            context_text = ctx['context'].lower()
            domain = ctx['domain']
            # Index by keywords found in the context body.
            for keyword, variants in hindi_keywords.items():
                for variant in variants:
                    if variant in context_text:
                        self.keyword_index[keyword].add(idx)
            # Index by domain label so domain names act as keywords too.
            self.keyword_index[domain].add(idx)
            # Index keywords appearing in the attached questions as well.
            for qa in ctx['qas']:
                question_text = qa['question'].lower()
                for keyword, variants in hindi_keywords.items():
                    for variant in variants:
                        if variant in question_text:
                            self.keyword_index[keyword].add(idx)

    def _create_qa_embeddings(self):
        """Create embeddings for all Q&A pairs for direct matching"""
        for ctx_idx, ctx in enumerate(self.contexts):
            for qa in ctx['qas']:
                qa_text = qa['question'] + " " + qa['answer']
                self.qa_pairs.append({
                    'context_idx': ctx_idx,
                    'qa': qa,
                    'combined_text': qa_text,
                    'domain': ctx['domain']
                })
        if self.qa_pairs:
            qa_texts = [qa['combined_text'] for qa in self.qa_pairs]
            self.question_embeddings = self.embedding_model.encode(
                qa_texts,
                show_progress_bar=True,
                batch_size=32
            )
            logger.info(f"Created embeddings for {len(self.qa_pairs)} Q&A pairs")

    def retrieve_contexts(self, query: str, top_k: int = 5, strategy: str = 'hybrid') -> List[Dict[str, Any]]:
        """
        Retrieve relevant contexts using different strategies
        Args:
            query: User query
            top_k: Number of contexts to retrieve
            strategy: 'semantic', 'keyword', 'hybrid', 'domain_aware'
        """
        if strategy == 'semantic':
            return self._semantic_retrieval(query, top_k)
        elif strategy == 'keyword':
            return self._keyword_retrieval(query, top_k)
        elif strategy == 'domain_aware':
            return self._domain_aware_retrieval(query, top_k)
        else:  # hybrid (default, and fallback for unknown strategy names)
            return self._hybrid_retrieval(query, top_k)

    def _semantic_retrieval(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """Pure semantic similarity retrieval"""
        if self.context_embeddings is None:
            return []
        query_embedding = self.embedding_model.encode([query])
        similarities = cosine_similarity(query_embedding, self.context_embeddings)[0]
        # Indices of the top_k most similar contexts, best first.
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for idx in top_indices:
            if similarities[idx] > self.MIN_SIMILARITY:
                results.append({
                    'context': self.contexts[idx],
                    'similarity': float(similarities[idx]),
                    'method': 'semantic'
                })
        return results

    def _keyword_retrieval(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """Keyword-based retrieval"""
        query_lower = query.lower()
        candidate_indices = set()
        # Collect every context indexed under a keyword that occurs in the query.
        for keyword, indices in self.keyword_index.items():
            if keyword in query_lower:
                candidate_indices.update(indices)
        # Score candidates by word overlap with the query.
        scored_candidates = []
        for idx in candidate_indices:
            score = self._calculate_keyword_score(query_lower, self.contexts[idx])
            scored_candidates.append((idx, score))
        # Sort by score and return top_k.
        scored_candidates.sort(key=lambda x: x[1], reverse=True)
        results = []
        for idx, score in scored_candidates[:top_k]:
            results.append({
                'context': self.contexts[idx],
                'similarity': score,
                'method': 'keyword'
            })
        return results

    def _domain_aware_retrieval(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """Domain-aware retrieval: restrict the search to the most relevant domains"""
        # Guard (bug fix): without context embeddings the filtered-embedding
        # slice below would raise; mirror _semantic_retrieval's behavior.
        if self.context_embeddings is None:
            return []
        # First, identify the most relevant domain(s).
        query_embedding = self.embedding_model.encode([query])
        domain_similarities = {}
        for domain, domain_embedding in self.domain_embeddings.items():
            similarity = cosine_similarity(
                query_embedding.reshape(1, -1),
                domain_embedding.reshape(1, -1)
            )[0][0]
            domain_similarities[domain] = similarity
        # Keep the top 2 domains.
        top_domains = sorted(domain_similarities.items(), key=lambda x: x[1], reverse=True)[:2]
        top_domain_names = {d[0] for d in top_domains}  # hoisted out of the loop
        # Filter contexts down to those domains.
        domain_filtered_contexts = []
        for i, ctx in enumerate(self.contexts):
            if ctx['domain'] in top_domain_names:
                domain_filtered_contexts.append((i, ctx))
        if not domain_filtered_contexts:
            # Fall back to an unrestricted semantic search.
            return self._semantic_retrieval(query, top_k)
        # Perform semantic retrieval within the filtered subset.
        filtered_indices = [i for i, _ in domain_filtered_contexts]
        filtered_embeddings = self.context_embeddings[filtered_indices]
        similarities = cosine_similarity(query_embedding, filtered_embeddings)[0]
        top_local_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for local_idx in top_local_indices:
            global_idx = filtered_indices[local_idx]
            if similarities[local_idx] > self.MIN_SIMILARITY:
                results.append({
                    'context': self.contexts[global_idx],
                    'similarity': float(similarities[local_idx]),
                    'method': 'domain_aware',
                    'domain': self.contexts[global_idx]['domain']
                })
        return results

    def _hybrid_retrieval(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """Combine semantic and keyword retrieval with a weighted score"""
        # Get results from both methods.
        semantic_results = self._semantic_retrieval(query, top_k)
        keyword_results = self._keyword_retrieval(query, top_k)
        # Merge the two result sets, keyed by context object identity
        # (contexts are shared dict objects, so id() is a stable key here).
        combined_results = {}
        for result in semantic_results:
            ctx_id = id(result['context'])
            combined_results[ctx_id] = {
                'context': result['context'],
                'semantic_score': result['similarity'],
                'keyword_score': 0.0,
                'method': 'hybrid'
            }
        for result in keyword_results:
            ctx_id = id(result['context'])
            if ctx_id in combined_results:
                combined_results[ctx_id]['keyword_score'] = result['similarity']
            else:
                combined_results[ctx_id] = {
                    'context': result['context'],
                    'semantic_score': 0.0,
                    'keyword_score': result['similarity'],
                    'method': 'hybrid'
                }
        # Calculate the combined score for each merged candidate.
        final_results = []
        for result in combined_results.values():
            # Weighted combination: 70% semantic, 30% keyword.
            combined_score = (0.7 * result['semantic_score'] +
                              0.3 * result['keyword_score'])
            final_results.append({
                'context': result['context'],
                'similarity': combined_score,
                'method': result['method'],
                'semantic_score': result['semantic_score'],
                'keyword_score': result['keyword_score']
            })
        # Sort by combined score, best first.
        final_results.sort(key=lambda x: x['similarity'], reverse=True)
        return final_results[:top_k]

    def _calculate_keyword_score(self, query: str, context: Dict[str, Any]) -> float:
        """Jaccard word-overlap between the query and a context (incl. its Q&As)"""
        context_text = (context['context'] + " " +
                        " ".join([qa['question'] + " " + qa['answer']
                                  for qa in context['qas']])).lower()
        query_words = set(query.split())
        context_words = set(context_text.split())
        # Jaccard similarity over whitespace-split word sets.
        intersection = len(query_words.intersection(context_words))
        union = len(query_words.union(context_words))
        if union == 0:
            return 0.0
        return intersection / union

    def find_exact_qa_match(self, query: str, threshold: float = 0.8) -> Dict[str, Any]:
        """
        Find exact Q&A matches for the query.

        Returns a dict with the matched ``qa``, its ``context``, the
        ``similarity`` and ``method`` when the best cosine similarity
        exceeds ``threshold``; otherwise None.
        """
        # Bug fix: question_embeddings starts life as a plain list (__init__)
        # and only becomes an ndarray after indexing, so the previous
        # ``self.question_embeddings.size`` raised AttributeError whenever no
        # data (or no QA pairs) had been loaded. len() works for both.
        if self.question_embeddings is None or len(self.question_embeddings) == 0:
            return None
        query_embedding = self.embedding_model.encode([query])
        similarities = cosine_similarity(query_embedding, self.question_embeddings)[0]
        best_match_idx = np.argmax(similarities)
        best_similarity = similarities[best_match_idx]
        if best_similarity > threshold:
            return {
                'qa': self.qa_pairs[best_match_idx]['qa'],
                'context': self.contexts[self.qa_pairs[best_match_idx]['context_idx']],
                'similarity': float(best_similarity),
                'method': 'exact_qa_match'
            }
        return None

    def get_domain_statistics(self) -> Dict[str, int]:
        """Get statistics about domains in the dataset"""
        domain_counts = defaultdict(int)
        for ctx in self.contexts:
            domain_counts[ctx['domain']] += 1
        return dict(domain_counts)
class AdvancedVATIKARetriever(VATIKARetriever):
    """
    Extended retriever with additional features for better performance:
    per-query result caching (FIFO-bounded) and two-stage reranking.
    """

    def __init__(self, embedding_model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        super().__init__(embedding_model_name)
        self.query_cache = {}        # cache_key -> retrieval results
        self.max_cache_size = 1000   # FIFO eviction beyond this size

    def retrieve_with_caching(self, query: str, top_k: int = 5, strategy: str = 'hybrid') -> List[Dict[str, Any]]:
        """Retrieve with caching for better performance"""
        cache_key = f"{query}_{top_k}_{strategy}"
        if cache_key in self.query_cache:
            logger.info(f"Cache hit for query: {query[:50]}...")
            return self.query_cache[cache_key]
        results = self.retrieve_contexts(query, top_k, strategy)
        # Cache management: simple FIFO eviction — dicts preserve insertion
        # order, so the first key is the oldest entry.
        if len(self.query_cache) >= self.max_cache_size:
            oldest_key = next(iter(self.query_cache))
            del self.query_cache[oldest_key]
        self.query_cache[cache_key] = results
        return results

    def retrieve_with_reranking(self, query: str, top_k: int = 5, rerank_top_k: int = 10) -> List[Dict[str, Any]]:
        """
        Two-stage retrieval: first retrieve more candidates, then rerank
        """
        # First stage: retrieve a wider candidate pool.
        candidates = self.retrieve_contexts(query, rerank_top_k, 'hybrid')
        if len(candidates) <= top_k:
            return candidates
        # Second stage: rerank with extra features beyond raw similarity.
        reranked_candidates = []
        for candidate in candidates:
            domain_relevance = self._calculate_domain_relevance(query, candidate['context']['domain'])
            qa_relevance = self._calculate_qa_relevance(query, candidate['context']['qas'])
            # Weighted blend: retrieval 50%, domain 30%, Q&A 20%.
            final_score = (0.5 * candidate['similarity'] +
                           0.3 * domain_relevance +
                           0.2 * qa_relevance)
            reranked_candidates.append({
                **candidate,
                'final_score': final_score,
                'domain_relevance': domain_relevance,
                'qa_relevance': qa_relevance
            })
        # Sort by final score, best first.
        reranked_candidates.sort(key=lambda x: x['final_score'], reverse=True)
        return reranked_candidates[:top_k]

    def _calculate_domain_relevance(self, query: str, domain: str) -> float:
        """Domain relevance score: 1.0 direct mention, 0.8 keyword hit, 0.1 otherwise"""
        # Per-domain Hindi/English keyword variants.
        domain_keywords = {
            'temple': ['मंदिर', 'देवालय', 'temple', 'पूजा', 'दर्शन'],
            'ghat': ['घाट', 'ghat', 'तट', 'गंगा'],
            'aarti': ['आरती', 'aarti', 'पूजा', 'गंगा'],
            'food': ['भोजन', 'खाना', 'food', 'खाने'],
            'travel': ['यात्रा', 'travel', 'जाना', 'पहुंचना'],
            'museum': ['संग्रहालय', 'museum', 'म्यूजियम'],
            'ashram': ['आश्रम', 'ashram'],
            'kund': ['कुंड', 'kund', 'तालाब'],
            'cruise': ['क्रूज़', 'cruise', 'नाव'],
            'toilet': ['शौचालय', 'toilet', 'टॉयलेट']
        }
        query_lower = query.lower()
        domain_lower = domain.lower()
        # Direct mention of the domain name in the query.
        if domain_lower in query_lower:
            return 1.0
        # Otherwise check any known keyword of that domain.
        for keyword in domain_keywords.get(domain_lower, ()):
            if keyword in query_lower:
                return 0.8
        return 0.1

    def _calculate_qa_relevance(self, query: str, qas: List[Dict[str, Any]]) -> float:
        """Best question-to-query embedding similarity across a context's Q&As"""
        if not qas:
            return 0.0
        max_relevance = 0.0
        for qa in qas:
            question_similarity = self._text_similarity(query, qa['question'])
            max_relevance = max(max_relevance, question_similarity)
        return max_relevance

    def _text_similarity(self, text1: str, text2: str) -> float:
        """Calculate text similarity using embeddings (0.0 on failure)"""
        try:
            embeddings = self.embedding_model.encode([text1, text2])
            similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
            return float(similarity)
        except Exception:
            # Bug fix: the previous bare ``except:`` also swallowed
            # SystemExit/KeyboardInterrupt and hid real errors silently;
            # catch Exception and log before falling back to 0.0.
            logger.warning("Text similarity computation failed", exc_info=True)
            return 0.0