# aivre/app/services/retrieval.py
# Last commit cfc8e23 by Vedang Barhate: "chore: copied from assist repo"
import logging
from app.core.config import Settings
from app.models.document import Document
from app.services.indexes import BM25Searcher, ColBERTSearcher
from app.services.reranking import CrossEncoderReranker
logger = logging.getLogger(__name__)


class HybridRetriever:
    """Hybrid retriever that fuses BM25 and (optionally) ColBERT search results.

    Behaviour is driven by ``Settings``:

    * ``use_colbert`` -- create a ColBERT searcher; outside ColBERT-only mode its
      results are fused with BM25 results.
    * ``colbert_only`` -- skip the BM25 index when building/saving/loading and
      answer queries from ColBERT alone (requires ``use_colbert`` as well).
    * ``use_reranking`` -- reorder fused results with a cross-encoder before
      truncating to ``top_k`` (not applied on the ColBERT-only path).

    Fusion min-max normalises each retriever's scores to [0, 1], multiplies them
    by that retriever's configured weight and sums the contributions per document.
    """

    def __init__(self, settings: Settings):
        """Create the searchers (and optional reranker) requested by ``settings``.

        ``BM25Searcher`` is always instantiated; ``ColBERTSearcher`` and
        ``CrossEncoderReranker`` only when ``use_colbert`` / ``use_reranking``
        are set. No index is built or loaded here -- call ``build_index`` or
        ``load_index`` afterwards.
        """
        logger.debug("Initializing HybridRetriever")
        self.settings = settings
        logger.debug("Creating BM25Searcher instance")
        self.bm25_searcher = BM25Searcher()
        self.colbert_searcher = None
        self.reranker = None

        if settings.use_colbert:
            logger.debug(
                "ColBERT enabled, creating ColBERTSearcher with model: %s",
                settings.colbert_model,
            )
            self.colbert_searcher = ColBERTSearcher(settings.colbert_model)
        else:
            logger.debug("ColBERT disabled")

        if settings.use_reranking:
            logger.debug(
                "Reranking enabled, creating CrossEncoderReranker with model: %s",
                settings.rerank_model,
            )
            self.reranker = CrossEncoderReranker(settings.rerank_model)
        else:
            logger.debug("Reranking disabled")

        if settings.colbert_only and self.colbert_searcher is None:
            # colbert_only makes build/save/load skip BM25, and without
            # use_colbert there is no ColBERT searcher either, so no index is
            # ever prepared while retrieve() falls back to the BM25 path.
            logger.warning(
                "colbert_only is set but use_colbert is disabled: neither index "
                "will be built or loaded, and retrieval will fall back to BM25"
            )

        logger.info(
            "HybridRetriever initialized with method: %s", self.get_retrieval_method()
        )

    def _bm25_enabled(self) -> bool:
        """Return True unless ColBERT-only mode disables the BM25 index."""
        return not self.settings.colbert_only

    def _colbert_enabled(self) -> bool:
        """Return True when ColBERT is configured and a searcher was created."""
        return bool(self.settings.use_colbert) and self.colbert_searcher is not None

    def build_index(self, documents: list[Document]) -> None:
        """Build every enabled index (BM25 and/or ColBERT) from ``documents``."""
        logger.info("Building indexes for %d documents", len(documents))
        if self._bm25_enabled():
            logger.debug("Building BM25 index")
            self.bm25_searcher.build_index(documents)
            logger.debug("BM25 index build completed")
        else:
            logger.debug("Skipping BM25 index (ColBERT-only mode)")

        if self._colbert_enabled():
            logger.debug("Building ColBERT index")
            self.colbert_searcher.build_index(documents)
            logger.debug("ColBERT index build completed")
        else:
            logger.debug("Skipping ColBERT index (not enabled)")
        logger.info("All enabled indexes built successfully")

    def save_index(self) -> None:
        """Persist every enabled index to its configured path.

        BM25 goes to ``settings.index_path`` and ColBERT to
        ``settings.colbert_index_path``.
        """
        logger.info("Saving indexes to disk")
        if self._bm25_enabled():
            logger.debug("Saving BM25 index to: %s", self.settings.index_path)
            self.bm25_searcher.save(self.settings.index_path)
            logger.debug("BM25 index saved successfully")
        else:
            logger.debug("Skipping BM25 index save (ColBERT-only mode)")

        if self._colbert_enabled():
            logger.debug(
                "Saving ColBERT index to: %s", self.settings.colbert_index_path
            )
            self.colbert_searcher.save(self.settings.colbert_index_path)
            logger.debug("ColBERT index saved successfully")
        else:
            logger.debug("Skipping ColBERT index save (not enabled)")
        logger.info("All enabled indexes saved successfully")

    def load_index(self) -> None:
        """Load every enabled index from the paths used by ``save_index``."""
        logger.info("Loading indexes from disk")
        if self._bm25_enabled():
            logger.debug("Loading BM25 index from: %s", self.settings.index_path)
            self.bm25_searcher.load(self.settings.index_path)
            logger.debug("BM25 index loaded successfully")
        else:
            logger.debug("Skipping BM25 index load (ColBERT-only mode)")

        if self._colbert_enabled():
            logger.debug(
                "Loading ColBERT index from: %s", self.settings.colbert_index_path
            )
            self.colbert_searcher.load(self.settings.colbert_index_path)
            logger.debug("ColBERT index loaded successfully")
        else:
            logger.debug("Skipping ColBERT index load (not enabled)")
        logger.info("All enabled indexes loaded successfully")

    def retrieve(self, query: str, top_k: int) -> list[tuple[Document, float]]:
        """Return up to ``top_k`` ``(document, score)`` pairs for ``query``.

        In ColBERT-only mode the ColBERT searcher's results are returned as-is
        (no fusion, no reranking). Otherwise ColBERT (if enabled, fetching
        ``colbert_top_k``) and BM25 (fetching ``bm25_top_k``) results are fused
        by ``_combine_results``, optionally reranked, then truncated to
        ``top_k``.
        """
        logger.debug(
            "Hybrid retrieval initiated with query: '%s', top_k=%s", query, top_k
        )
        if self.settings.colbert_only and self.colbert_searcher is not None:
            # Fast path: the reranker, even if configured, is not applied here.
            logger.debug("Using ColBERT-only retrieval")
            results = self.colbert_searcher.search(query, top_k)
            logger.info("ColBERT-only retrieval completed: %d results", len(results))
            return results

        logger.debug("Using hybrid retrieval approach")
        all_results = []
        if self._colbert_enabled():
            logger.debug(
                "Performing ColBERT search with top_k=%s, weight=%s",
                self.settings.colbert_top_k,
                self.settings.colbert_weight,
            )
            colbert_results = self.colbert_searcher.search(
                query, self.settings.colbert_top_k
            )
            logger.debug("ColBERT search returned %d results", len(colbert_results))
            all_results.append((colbert_results, self.settings.colbert_weight))

        logger.debug(
            "Performing BM25 search with top_k=%s, weight=%s",
            self.settings.bm25_top_k,
            self.settings.bm25_weight,
        )
        bm25_results = self.bm25_searcher.search(query, self.settings.bm25_top_k)
        logger.debug("BM25 search returned %d results", len(bm25_results))
        all_results.append((bm25_results, self.settings.bm25_weight))

        combined = self._combine_results(all_results)
        logger.debug("Combined results: %d unique documents", len(combined))

        if self.settings.use_reranking and self.reranker is not None:
            logger.debug("Applying reranking to combined results")
            combined = self.reranker.rerank(query, combined)
            logger.debug("Reranking completed")
        else:
            logger.debug("Skipping reranking (not enabled)")

        final_results = combined[:top_k]
        logger.info(
            "Hybrid retrieval completed: %d final results returned", len(final_results)
        )
        return final_results

    @staticmethod
    def _min_max_normalize(scores: list[float]) -> list[float]:
        """Linearly rescale ``scores`` into [0, 1] (minimum -> 0, maximum -> 1).

        When every score is identical -- including the single-result case --
        there is no spread to rescale. Each result is then treated as the
        retriever's best match and mapped to 1.0, so that retriever's weight
        still counts in the fusion instead of collapsing to 0.
        """
        if not scores:
            return []
        low = min(scores)
        high = max(scores)
        if high == low:
            return [1.0] * len(scores)
        span = high - low
        return [(score - low) / span for score in scores]

    def _combine_results(
        self, results_with_weights: list[tuple[list[tuple[Document, float]], float]]
    ) -> list[tuple[Document, float]]:
        """Fuse several ranked result lists into one weighted ranking.

        Args:
            results_with_weights: ``(results, weight)`` pairs, one per retriever,
                where ``results`` is a list of ``(document, raw_score)``.

        Returns:
            ``(document, fused_score)`` pairs sorted by descending fused score.
            A fused score is the sum over retrievers of
            ``weight * _min_max_normalize(raw_scores)[i]``; documents with equal
            fused scores keep their first-seen order.
        """
        logger.debug(
            "Combining results from %d retrievers", len(results_with_weights)
        )
        doc_scores: dict[int, tuple[Document, float]] = {}
        total_docs_processed = 0

        for i, (results, weight) in enumerate(results_with_weights, start=1):
            logger.debug(
                "Processing retriever %d: %d results, weight=%s",
                i,
                len(results),
                weight,
            )
            if not results:
                logger.debug("Retriever %d returned no results, skipping", i)
                continue

            raw_scores = [score for _, score in results]
            normalized_scores = self._min_max_normalize(raw_scores)
            min_score = min(raw_scores)
            max_score = max(raw_scores)
            logger.debug(
                "Retriever %d score range: %.4f to %.4f (range: %.4f)",
                i,
                min_score,
                max_score,
                max_score - min_score,
            )

            new_docs = 0
            updated_docs = 0
            for (doc, score), normalized_score in zip(results, normalized_scores):
                weighted_score = normalized_score * weight
                # Documents are merged by object identity, so scores only
                # accumulate across retrievers when both return the *same*
                # Document instance.
                # NOTE(review): if the searchers load their documents separately
                # (e.g. each from its own index file), equal documents will not
                # merge here -- consider a stable document key instead.
                doc_id = id(doc)
                if doc_id in doc_scores:
                    current_doc, current_score = doc_scores[doc_id]
                    doc_scores[doc_id] = (current_doc, current_score + weighted_score)
                    updated_docs += 1
                    logger.debug(
                        "Updated existing doc: original_score=%.4f, "
                        "normalized=%.4f, weighted=%.4f",
                        score,
                        normalized_score,
                        weighted_score,
                    )
                else:
                    doc_scores[doc_id] = (doc, weighted_score)
                    new_docs += 1
                    logger.debug(
                        "Added new doc: original_score=%.4f, "
                        "normalized=%.4f, weighted=%.4f",
                        score,
                        normalized_score,
                        weighted_score,
                    )

            total_docs_processed += len(results)
            logger.debug(
                "Retriever %d processing completed: %d new docs, %d updated docs",
                i,
                new_docs,
                updated_docs,
            )

        logger.debug(
            "Score combination completed: %d unique documents from %d total results",
            len(doc_scores),
            total_docs_processed,
        )
        logger.debug("Sorting combined results by final scores")
        # sorted() is stable (also with reverse=True), so ties keep the order in
        # which documents were first seen.
        combined_results = sorted(
            doc_scores.values(), key=lambda item: item[1], reverse=True
        )

        if logger.isEnabledFor(logging.DEBUG):
            for rank, (doc, score) in enumerate(combined_results[:5], start=1):
                source = doc.metadata.get("filename", "unknown")
                logger.debug(
                    "Top %d combined result: score=%.4f, source=%s",
                    rank,
                    score,
                    source,
                )

        logger.debug(
            "Result combination completed: %d final combined results",
            len(combined_results),
        )
        return combined_results

    def get_retrieval_method(self) -> str:
        """Return a human-readable label for the configured retrieval mode.

        Derived from the configuration flags only: ``colbert_only`` takes
        precedence over ``use_colbert``; with neither set the mode is BM25-only.
        """
        if self.settings.colbert_only:
            method = "ColBERT-only"
        elif self.settings.use_colbert:
            method = "Hybrid (BM25 + ColBERT)"
        else:
            method = "BM25-only"
        logger.debug("Retrieval method: %s", method)
        return method