# aivre/app/services/reranking.py
# Author: Vedang Barhate
# Commit cfc8e23: "chore: copied from assist repo" (3.42 kB)
import logging
from sentence_transformers import CrossEncoder
from app.models.document import Document
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class CrossEncoderReranker:
    """Cross-encoder based reranking for improving retrieval quality.

    Each (query, document content) pair is scored by a ``CrossEncoder``
    model, and the results are reordered by that score, highest first.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder model on CPU.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.

        Raises:
            Exception: Whatever ``CrossEncoder`` raises while loading; the
                error is logged and re-raised unchanged.
        """
        logger.debug(f"Initializing CrossEncoderReranker with model: {model_name}")
        self.model_name = model_name

        logger.debug("Loading CrossEncoder model on CPU")
        try:
            self.model = CrossEncoder(model_name, device="cpu")
        except Exception as e:
            logger.error(f"Failed to initialize CrossEncoder model {model_name}: {e}")
            raise
        logger.info(f"Initialized reranker with model: {model_name}")

    def rerank(
        self,
        query: str,
        results: list[tuple[Document, float]],
        top_k: int | None = None,
    ) -> list[tuple[Document, float]]:
        """Rerank documents using the cross-encoder model.

        Args:
            query: The search query.
            results: List of (Document, score) tuples. The incoming score is
                only used for debug logging; it is replaced in the output.
            top_k: Maximum number of results to return. ``None`` returns all
                results, ``0`` returns an empty list.

        Returns:
            A new list of (Document, rerank_score) tuples sorted by
            rerank score in descending order, truncated to ``top_k`` when
            given. If ``results`` is empty it is returned as-is.

        Raises:
            ValueError: If ``top_k`` is negative, or if the model returns a
                different number of scores than there are input results.
            Exception: Whatever ``self.model.predict`` raises; the error is
                logged and re-raised unchanged.
        """
        # A negative value would make the final slice drop items from the
        # tail instead of limiting the count, so reject it before doing any
        # model work.
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative or None, got {top_k}")

        logger.debug(
            f"Reranking initiated with query: '{query}', {len(results)} results, top_k={top_k}"
        )

        if not results:
            logger.debug("No results to rerank, returning empty list")
            return results

        logger.debug("Creating query-document pairs for reranking")
        pairs = [(query, doc.content) for doc, _ in results]
        logger.debug(f"Created {len(pairs)} query-document pairs")

        logger.debug("Computing rerank scores using CrossEncoder")
        try:
            rerank_scores = self.model.predict(pairs)
        except Exception as e:
            logger.error(f"Error during reranking prediction: {e}")
            raise
        logger.debug(f"Generated {len(rerank_scores)} rerank scores")

        logger.debug("Combining documents with new rerank scores")
        # strict=True: a score/result count mismatch must fail loudly rather
        # than silently dropping documents from the output.
        reranked_results = [
            (doc, float(rerank_score))
            for (doc, _), rerank_score in zip(results, rerank_scores, strict=True)
        ]

        # Per-document messages are only formatted when DEBUG is enabled,
        # matching the guard used for the top-5 summary below.
        if logger.isEnabledFor(logging.DEBUG):
            for i, ((_, original_score), (_, rerank_score)) in enumerate(
                zip(results, reranked_results, strict=True)
            ):
                logger.debug(
                    f"Document {i + 1}: original_score={original_score:.4f}, rerank_score={rerank_score:.4f}"
                )

        logger.debug("Sorting results by rerank scores (descending)")
        # list.sort is stable, so documents with equal scores keep their
        # incoming relative order.
        reranked_results.sort(key=lambda x: x[1], reverse=True)

        if logger.isEnabledFor(logging.DEBUG):
            for i, (doc, score) in enumerate(reranked_results[:5]):
                logger.debug(
                    f"Top {i + 1} after reranking: score={score:.4f}, source={doc.metadata.get('filename', 'unknown')}"
                )

        if top_k is not None:
            logger.debug(f"Applying top_k filter: keeping top {top_k} results")
            final_results = reranked_results[:top_k]
            logger.info(
                f"Reranking completed: returned {len(final_results)} of {len(results)} results (top_k={top_k})"
            )
            return final_results

        logger.info(
            f"Reranking completed: returned all {len(reranked_results)} reranked results"
        )
        return reranked_results