Spaces:
Running
Running
import logging | |
from app.core.config import Settings | |
from app.models.document import Document | |
from app.services.indexes import BM25Searcher, ColBERTSearcher | |
from app.services.reranking import CrossEncoderReranker | |
logger = logging.getLogger(__name__)


class HybridRetriever:
    """Retriever that fuses BM25 and (optionally) ColBERT results.

    A BM25 searcher is always created. A ColBERT searcher is created when
    ``settings.use_colbert`` is set and a cross-encoder reranker when
    ``settings.use_reranking`` is set. In hybrid mode the per-retriever
    scores are min-max normalised, multiplied by the retriever's weight and
    summed per document.
    """

    def __init__(self, settings: Settings):
        """Create the searchers and the optional reranker from ``settings``.

        Args:
            settings: Configuration object. This class reads
                ``use_colbert``, ``colbert_model``, ``colbert_only``,
                ``use_reranking``, ``rerank_model``, ``index_path``,
                ``colbert_index_path``, ``bm25_top_k``, ``bm25_weight``,
                ``colbert_top_k`` and ``colbert_weight`` from it.
        """
        logger.debug("Initializing HybridRetriever")
        self.settings = settings

        logger.debug("Creating BM25Searcher instance")
        self.bm25_searcher = BM25Searcher()
        self.colbert_searcher = None
        self.reranker = None

        if settings.use_colbert:
            logger.debug(
                "ColBERT enabled, creating ColBERTSearcher with model: %s",
                settings.colbert_model,
            )
            self.colbert_searcher = ColBERTSearcher(settings.colbert_model)
        else:
            logger.debug("ColBERT disabled")
            if settings.colbert_only:
                # Without a ColBERT searcher there is nothing to search in
                # ColBERT-only mode, so BM25 is used instead (see
                # _colbert_only_active).
                logger.warning(
                    "colbert_only is set but use_colbert is disabled; "
                    "falling back to BM25"
                )

        if settings.use_reranking:
            logger.debug(
                "Reranking enabled, creating CrossEncoderReranker with model: %s",
                settings.rerank_model,
            )
            self.reranker = CrossEncoderReranker(settings.rerank_model)
        else:
            logger.debug("Reranking disabled")

        retrieval_method = self.get_retrieval_method()
        logger.info("HybridRetriever initialized with method: %s", retrieval_method)

    def _colbert_active(self) -> bool:
        """Return True when ColBERT is enabled and its searcher exists."""
        # Compare against None explicitly: truthiness of a searcher object
        # could depend on __len__/__bool__ of a class not visible here.
        return bool(self.settings.use_colbert) and self.colbert_searcher is not None

    def _colbert_only_active(self) -> bool:
        """Return True when ColBERT-only mode is requested and usable."""
        return bool(self.settings.colbert_only) and self._colbert_active()

    def build_index(self, documents: list[Document]) -> None:
        """Build every enabled search index from ``documents``.

        BM25 is skipped only when ColBERT-only mode is actually usable, so
        that ``retrieve`` never falls back to a BM25 index that was not built.
        """
        logger.info("Building indexes for %d documents", len(documents))

        if self._colbert_only_active():
            logger.debug("Skipping BM25 index (ColBERT-only mode)")
        else:
            logger.debug("Building BM25 index")
            self.bm25_searcher.build_index(documents)
            logger.debug("BM25 index build completed")

        if self._colbert_active():
            logger.debug("Building ColBERT index")
            self.colbert_searcher.build_index(documents)
            logger.debug("ColBERT index build completed")
        else:
            logger.debug("Skipping ColBERT index (not enabled)")

        logger.info("All enabled indexes built successfully")

    def save_index(self) -> None:
        """Save every enabled index to the paths given in the settings."""
        logger.info("Saving indexes to disk")

        if self._colbert_only_active():
            logger.debug("Skipping BM25 index save (ColBERT-only mode)")
        else:
            logger.debug("Saving BM25 index to: %s", self.settings.index_path)
            self.bm25_searcher.save(self.settings.index_path)
            logger.debug("BM25 index saved successfully")

        if self._colbert_active():
            logger.debug(
                "Saving ColBERT index to: %s", self.settings.colbert_index_path
            )
            self.colbert_searcher.save(self.settings.colbert_index_path)
            logger.debug("ColBERT index saved successfully")
        else:
            logger.debug("Skipping ColBERT index save (not enabled)")

        logger.info("All enabled indexes saved successfully")

    def load_index(self) -> None:
        """Load every enabled index from the paths given in the settings."""
        logger.info("Loading indexes from disk")

        if self._colbert_only_active():
            logger.debug("Skipping BM25 index load (ColBERT-only mode)")
        else:
            logger.debug("Loading BM25 index from: %s", self.settings.index_path)
            self.bm25_searcher.load(self.settings.index_path)
            logger.debug("BM25 index loaded successfully")

        if self._colbert_active():
            logger.debug(
                "Loading ColBERT index from: %s", self.settings.colbert_index_path
            )
            self.colbert_searcher.load(self.settings.colbert_index_path)
            logger.debug("ColBERT index loaded successfully")
        else:
            logger.debug("Skipping ColBERT index load (not enabled)")

        logger.info("All enabled indexes loaded successfully")

    def retrieve(self, query: str, top_k: int) -> list[tuple[Document, float]]:
        """Retrieve the ``top_k`` best documents for ``query``.

        In ColBERT-only mode the ColBERT searcher's results are returned
        unchanged. Otherwise ColBERT (when active) and BM25 are queried with
        their own ``*_top_k`` settings, fused with ``_combine_results``,
        optionally reranked, and truncated to ``top_k``.

        Args:
            query: The search query.
            top_k: Maximum number of results to return.

        Returns:
            ``(document, score)`` pairs, best first.
        """
        logger.debug(
            "Hybrid retrieval initiated with query: '%s', top_k=%s", query, top_k
        )

        if self._colbert_only_active():
            logger.debug("Using ColBERT-only retrieval")
            results = self.colbert_searcher.search(query, top_k)
            logger.info("ColBERT-only retrieval completed: %d results", len(results))
            return results

        logger.debug("Using hybrid retrieval approach")
        all_results = []

        if self._colbert_active():
            logger.debug(
                "Performing ColBERT search with top_k=%s, weight=%s",
                self.settings.colbert_top_k,
                self.settings.colbert_weight,
            )
            colbert_results = self.colbert_searcher.search(
                query, self.settings.colbert_top_k
            )
            logger.debug("ColBERT search returned %d results", len(colbert_results))
            all_results.append((colbert_results, self.settings.colbert_weight))

        logger.debug(
            "Performing BM25 search with top_k=%s, weight=%s",
            self.settings.bm25_top_k,
            self.settings.bm25_weight,
        )
        bm25_results = self.bm25_searcher.search(query, self.settings.bm25_top_k)
        logger.debug("BM25 search returned %d results", len(bm25_results))
        all_results.append((bm25_results, self.settings.bm25_weight))

        logger.debug("Combining results from %d retrievers", len(all_results))
        combined = self._combine_results(all_results)
        logger.debug("Combined results: %d unique documents", len(combined))

        if self.settings.use_reranking and self.reranker is not None:
            logger.debug("Applying reranking to combined results")
            combined = self.reranker.rerank(query, combined)
            logger.debug("Reranking completed")
        else:
            logger.debug("Skipping reranking (not enabled)")

        final_results = combined[:top_k]
        logger.info(
            "Hybrid retrieval completed: %d final results returned",
            len(final_results),
        )
        return final_results

    def _combine_results(
        self, results_with_weights: list[tuple[list[tuple[Document, float]], float]]
    ) -> list[tuple[Document, float]]:
        """Fuse several result lists into one list sorted by weighted score.

        Each retriever's scores are min-max normalised to ``[0, 1]``,
        multiplied by that retriever's weight and summed per document.

        Args:
            results_with_weights: ``(results, weight)`` pairs, where
                ``results`` is a list of ``(document, score)`` tuples.

        Returns:
            ``(document, combined_score)`` pairs, highest score first.
        """
        logger.debug("Combining results from %d retrievers", len(results_with_weights))
        # Hoisted so per-document messages cost nothing when DEBUG is off.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        doc_scores = {}
        total_docs_processed = 0

        for position, (results, weight) in enumerate(results_with_weights, start=1):
            logger.debug(
                "Processing retriever %d: %d results, weight=%s",
                position,
                len(results),
                weight,
            )
            if not results:
                logger.debug("Retriever %d returned no results, skipping", position)
                continue

            scores = [score for _, score in results]
            min_score = min(scores)
            max_score = max(scores)
            score_range = max_score - min_score
            # A single hit (or all-equal scores) has no spread to normalise.
            # Treat those hits as full matches; mapping them to 0 would
            # silently discard this retriever's weight.
            all_equal = score_range == 0
            logger.debug(
                "Retriever %d score range: %.4f to %.4f (range: %.4f)",
                position,
                min_score,
                max_score,
                score_range,
            )

            new_docs = 0
            updated_docs = 0
            for doc, score in results:
                if all_equal:
                    normalized_score = 1.0
                else:
                    normalized_score = (score - min_score) / score_range
                weighted_score = normalized_score * weight

                # NOTE(review): documents are matched by object identity, so
                # scores are merged only when the retrievers hand back the
                # very same Document instances - confirm this still holds
                # after the indexes are loaded from disk.
                doc_id = id(doc)
                if doc_id in doc_scores:
                    current_doc, current_score = doc_scores[doc_id]
                    doc_scores[doc_id] = (current_doc, current_score + weighted_score)
                    updated_docs += 1
                    action = "Updated existing doc"
                else:
                    doc_scores[doc_id] = (doc, weighted_score)
                    new_docs += 1
                    action = "Added new doc"

                if debug_enabled:
                    logger.debug(
                        "%s: original_score=%.4f, normalized=%.4f, weighted=%.4f",
                        action,
                        score,
                        normalized_score,
                        weighted_score,
                    )

            total_docs_processed += len(results)
            logger.debug(
                "Retriever %d processing completed: %d new docs, %d updated docs",
                position,
                new_docs,
                updated_docs,
            )

        logger.debug(
            "Score combination completed: %d unique documents from %d total results",
            len(doc_scores),
            total_docs_processed,
        )

        logger.debug("Sorting combined results by final scores")
        combined_results = sorted(
            doc_scores.values(), key=lambda pair: pair[1], reverse=True
        )

        if debug_enabled:
            for rank, (doc, score) in enumerate(combined_results[:5], start=1):
                source = doc.metadata.get("filename", "unknown")
                logger.debug(
                    "Top %d combined result: score=%.4f, source=%s",
                    rank,
                    score,
                    source,
                )

        logger.debug(
            "Result combination completed: %d final combined results",
            len(combined_results),
        )
        return combined_results

    def get_retrieval_method(self) -> str:
        """Return a short description of the retrieval mode in effect.

        The answer reflects what ``retrieve`` will actually do: ColBERT-only
        is reported only when a ColBERT searcher is available.
        """
        if self._colbert_only_active():
            method = "ColBERT-only"
        elif self._colbert_active():
            method = "Hybrid (BM25 + ColBERT)"
        else:
            method = "BM25-only"
        logger.debug("Retrieval method: %s", method)
        return method