Spaces:
Sleeping
Sleeping
import logging | |
from sentence_transformers import CrossEncoder | |
from app.models.document import Document | |
logger = logging.getLogger(__name__) | |
class CrossEncoderReranker: | |
"""Cross-encoder based reranking for improving retrieval quality""" | |
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): | |
logger.debug(f"Initializing CrossEncoderReranker with model: {model_name}") | |
self.model_name = model_name | |
logger.debug("Loading CrossEncoder model on CPU") | |
try: | |
self.model = CrossEncoder(model_name, device="cpu") | |
logger.info(f"Initialized reranker with model: {model_name}") | |
except Exception as e: | |
logger.error(f"Failed to initialize CrossEncoder model {model_name}: {e}") | |
raise | |
def rerank( | |
self, | |
query: str, | |
results: list[tuple[Document, float]], | |
top_k: int | None = None, | |
) -> list[tuple[Document, float]]: | |
""" | |
Rerank documents using cross-encoder model | |
Args: | |
query: The search query | |
results: List of (Document, score) tuples | |
top_k: Number of top results to return (None = all) | |
Returns: | |
Reranked list of (Document, score) tuples | |
""" | |
logger.debug( | |
f"Reranking initiated with query: '{query}', {len(results)} results, top_k={top_k}" | |
) | |
if not results: | |
logger.debug("No results to rerank, returning empty list") | |
return results | |
logger.debug("Creating query-document pairs for reranking") | |
pairs = [(query, doc.content) for doc, _ in results] | |
logger.debug(f"Created {len(pairs)} query-document pairs") | |
logger.debug("Computing rerank scores using CrossEncoder") | |
try: | |
rerank_scores = self.model.predict(pairs) | |
logger.debug(f"Generated {len(rerank_scores)} rerank scores") | |
except Exception as e: | |
logger.error(f"Error during reranking prediction: {e}") | |
raise | |
logger.debug("Combining documents with new rerank scores") | |
reranked_results = [] | |
for i, ((doc, original_score), rerank_score) in enumerate( | |
zip(results, rerank_scores, strict=False) | |
): | |
reranked_results.append((doc, float(rerank_score))) | |
logger.debug( | |
f"Document {i + 1}: original_score={original_score:.4f}, rerank_score={rerank_score:.4f}" | |
) | |
logger.debug("Sorting results by rerank scores (descending)") | |
reranked_results.sort(key=lambda x: x[1], reverse=True) | |
if logger.isEnabledFor(logging.DEBUG): | |
for i, (doc, score) in enumerate(reranked_results[:5]): | |
logger.debug( | |
f"Top {i + 1} after reranking: score={score:.4f}, source={doc.metadata.get('filename', 'unknown')}" | |
) | |
if top_k is not None: | |
logger.debug(f"Applying top_k filter: keeping top {top_k} results") | |
final_results = reranked_results[:top_k] | |
logger.info( | |
f"Reranking completed: returned {len(final_results)} of {len(results)} results (top_k={top_k})" | |
) | |
return final_results | |
logger.info( | |
f"Reranking completed: returned all {len(reranked_results)} reranked results" | |
) | |
return reranked_results | |