|
import time
from collections import defaultdict

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from scipy import stats
from sentence_transformers import SentenceTransformer, SparseEncoder
from sentence_transformers.quantization import quantize_embeddings
from usearch.index import Index
|
|
|
|
|
|
|
|
|
# Corpus: the 250K most-viewed paragraphs of the French-language Wikipedia,
# with precomputed SPLADE sparse embeddings in the "sparse_emb" column
wikipedia_dataset = load_dataset(
    "CATIE-AQ/wikipedia_fr_2022_250K", split="train", num_proc=4
).select_columns(["title", "text", "wiki_id", "sparse_emb"])
|
|
|
|
|
|
|
def add_link(example):
    # Turn the title into a markdown link to the corresponding Wikipedia article
    example["title"] = f'[{example["title"]}](https://fr.wikipedia.org/wiki?curid={example["wiki_id"]})'
    return example
|
wikipedia_dataset = wikipedia_dataset.map(add_link) |
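
# For illustration (made-up id): a row with title "Paris" and wiki_id 12345 becomes
# "[Paris](https://fr.wikipedia.org/wiki?curid=12345)", which the results table renders as a link.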
|
|
|
|
|
|
|
# Pre-built indexes over the dense embeddings:
# - usearch int8 index, memory-mapped (view=True), used for rescoring
# - FAISS flat binary index for exact Hamming-distance search
# - FAISS binary IVF index for faster, approximate search
int8_view = Index.restore("wikipedia_fr_2022_250K_int8_usearch.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_faiss.index")
binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_ivf_faiss.index")
|
|
|
|
|
|
|
# Embedding models: a dense bi-encoder and a sparse SPLADE encoder
dense_model = SentenceTransformer("OrdalieTech/Solon-embeddings-large-0.1")
sparse_model = SparseEncoder("CATIE-AQ/SPLADE_camembert-base_STS")
|
|
|
|
|
|
|
def reciprocal_rank_fusion(dense_results, sparse_results, k=20): |
|
""" |
|
Perform Reciprocal Rank Fusion to combine dense and sparse retrieval results |
|
|
|
Args: |
|
dense_results: List of (doc_id, score) from dense retrieval |
|
sparse_results: List of (doc_id, score) from sparse retrieval |
|
k: RRF parameter (default 20) |
|
|
|
Returns: |
|
List of (doc_id, rrf_score) sorted by RRF score |
|
""" |
|
rrf_scores = defaultdict(float) |
|
|
|
|
|
for rank, (doc_id, _) in enumerate(dense_results, 1): |
|
rrf_scores[doc_id] += 1 / (k + rank) |
|
|
|
|
|
for rank, (doc_id, _) in enumerate(sparse_results, 1): |
|
rrf_scores[doc_id] += 1 / (k + rank) |
|
|
|
|
|
sorted_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) |
|
return sorted_results |
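
# Minimal sanity check for RRF (made-up doc ids and scores):
#   dense  = [(7, 0.91), (3, 0.88)]   # doc 7 ranked 1st, doc 3 ranked 2nd
#   sparse = [(7, 0.52), (3, 0.41)]   # same ranking on the sparse side
#   reciprocal_rank_fusion(dense, sparse, k=20)
#   -> doc 7: 1/(20+1) + 1/(20+1) ≈ 0.0952 ; doc 3: 1/(20+2) + 1/(20+2) ≈ 0.0909
# Only the ranks matter: the raw scores never enter the RRF formula.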
|
|
|
|
|
def normalized_score_fusion(dense_results, sparse_results, dense_weight=0.5, sparse_weight=0.5): |
|
""" |
|
Perform Normalized Score Fusion (NSF) with z-score normalization |
|
|
|
Args: |
|
dense_results: List of (doc_id, score) from dense retrieval |
|
sparse_results: List of (doc_id, score) from sparse retrieval |
|
dense_weight: Weight for dense scores (default 0.5) |
|
sparse_weight: Weight for sparse scores (default 0.5) |
|
|
|
Returns: |
|
List of (doc_id, normalized_score) sorted by normalized score |
|
""" |
|
|
|
dense_scores = np.array([score for _, score in dense_results]) |
|
sparse_scores = np.array([score for _, score in sparse_results]) |
|
|
|
|
|
if len(dense_scores) > 1 and np.std(dense_scores) > 1e-10: |
|
dense_scores_norm = stats.zscore(dense_scores) |
|
else: |
|
dense_scores_norm = np.zeros_like(dense_scores) |
|
|
|
if len(sparse_scores) > 1 and np.std(sparse_scores) > 1e-10: |
|
sparse_scores_norm = stats.zscore(sparse_scores) |
|
else: |
|
sparse_scores_norm = np.zeros_like(sparse_scores) |
|
|
|
|
|
dense_norm_dict = {doc_id: score for (doc_id, _), score in zip(dense_results, dense_scores_norm)} |
|
sparse_norm_dict = {doc_id: score for (doc_id, _), score in zip(sparse_results, sparse_scores_norm)} |
|
|
|
|
|
all_doc_ids = set() |
|
all_doc_ids.update(doc_id for doc_id, _ in dense_results) |
|
all_doc_ids.update(doc_id for doc_id, _ in sparse_results) |
|
|
|
|
|
nsf_scores = {} |
|
for doc_id in all_doc_ids: |
|
dense_norm_score = dense_norm_dict.get(doc_id, 0.0) |
|
sparse_norm_score = sparse_norm_dict.get(doc_id, 0.0) |
|
|
|
|
|
nsf_scores[doc_id] = (dense_weight * dense_norm_score + |
|
sparse_weight * sparse_norm_score) |
|
|
|
|
|
sorted_results = sorted(nsf_scores.items(), key=lambda x: x[1], reverse=True) |
|
return sorted_results |
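
# Minimal sanity check for NSF (made-up scores; with two values, z-scores are ±1):
#   dense  = [(7, 0.9), (3, 0.5)]    # doc 7 -> +1.0, doc 3 -> -1.0
#   sparse = [(3, 12.0), (7, 4.0)]   # doc 3 -> +1.0, doc 7 -> -1.0
#   normalized_score_fusion(dense, sparse)            # equal weights: both docs tie at 0.0
#   normalized_score_fusion(dense, sparse, 0.8, 0.2)  # doc 7: 0.8*1.0 + 0.2*(-1.0) = 0.6
# Unlike RRF, the magnitudes of the (normalized) scores matter, not just the ranks.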
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build an in-memory sparse index: doc_id -> (non-zero vocabulary indices, term weights)
sparse_index = {}
for i, sparse_emb in enumerate(wikipedia_dataset["sparse_emb"]):
    sparse_emb = np.array(sparse_emb)

    # Keep only the non-zero SPLADE term weights
    non_zero_indices = np.nonzero(sparse_emb)[0]
    non_zero_values = sparse_emb[non_zero_indices]
|
sparse_index[i] = (non_zero_indices, non_zero_values) |
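
# Storing only the non-zero (index, value) pairs keeps the index small: a SPLADE vector
# is vocabulary-sized but only a small fraction of its entries are non-zero.
# For example (made-up values), [0.0, 1.2, 0.0, 0.7] is stored as (array([1, 3]), array([1.2, 0.7])).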
|
|
|
|
|
def sparse_search(query, top_k=20): |
|
""" |
|
Perform sparse retrieval using SPLADE representations with a dictionary-based sparse index. |
|
""" |
|
|
|
    # Encode the query with SPLADE; the encoder returns a (sparse) torch tensor
    query_sparse_vector = sparse_model.encode(query, convert_to_numpy=False)
    query_sparse_vector = query_sparse_vector.to_dense().cpu().numpy()

    # Keep only the query's non-zero term weights
    query_non_zero_indices = np.nonzero(query_sparse_vector)[0]
    query_non_zero_values = query_sparse_vector[query_non_zero_indices]

    # Score each document by the dot product over the terms it shares with the query
    scores = defaultdict(float)
    for doc_id, (doc_indices, doc_values) in sparse_index.items():
        common, query_pos, doc_pos = np.intersect1d(
            query_non_zero_indices, doc_indices, return_indices=True
        )
        if len(common) > 0:
            scores[doc_id] += float(np.dot(query_non_zero_values[query_pos], doc_values[doc_pos]))
|
|
|
|
|
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) |
|
return sorted_scores[:top_k] |
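
# Hedged usage sketch (illustrative query; actual ids and scores depend on the corpus):
#   sparse_search("histoire de la Révolution française", top_k=3)
#   -> [(doc_id, score), ...] where each score is a SPLADE dot product over shared terms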
|
|
|
|
|
|
|
def search(query, top_k: int = 20, rescore_multiplier: int = 1, use_approx: bool = False,
           fusion_method: str = "rrf", rrf_k: int = 20, dense_weight: float = 0.5, sparse_weight: float = 0.5):
    """
    Hybrid search: binary dense retrieval with int8 rescoring, optionally fused with
    SPLADE sparse retrieval via RRF or NSF. Returns a results DataFrame and timing info.
    """
|
total_start_time = time.time() |
|
|
|
|
|
    # 1. Embed the query with the dense model ("query: " prompt for asymmetric search)
    start_time = time.time()
    query_embedding = dense_model.encode(query, prompt="query: ")
    embed_time = time.time() - start_time
|
|
|
    # 2. Quantize the query embedding to packed unsigned binary for the binary indexes
    start_time = time.time()
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.time() - start_time
|
|
|
    # 3. Binary search (exact flat index or approximate IVF), over-fetching candidates
    index = binary_ivf if use_approx else binary_index
    start_time = time.time()
    _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier * 2)
    binary_ids = binary_ids[0]
    # The IVF index can pad its output with -1 when it finds fewer candidates; drop them
    binary_ids = binary_ids[binary_ids != -1]
    dense_search_time = time.time() - start_time
|
|
|
    # 4. Fetch the candidates' int8 embeddings from the memory-mapped usearch index
    start_time = time.time()
    int8_embeddings = int8_view[binary_ids].astype(int)
    load_time = time.time() - start_time
|
|
|
    # 5. Rescore the candidates against the full-precision query embedding
    start_time = time.time()
    scores = query_embedding @ int8_embeddings.T
    rescore_time = time.time() - start_time

    # Candidate list for fusion: (corpus doc_id, rescored dense score)
    dense_results = [(binary_ids[i], scores[i]) for i in range(len(binary_ids))]
    dense_results.sort(key=lambda x: x[1], reverse=True)
|
|
|
    timing_info = {
        "Time to embed the query (dense)": f"{embed_time:.4f} s",
        "Time to quantize": f"{quantize_time:.4f} s",
        "Time for the dense search": f"{dense_search_time:.4f} s",
        "Time to load candidates": f"{load_time:.4f} s",
        "Time to rescore": f"{rescore_time:.4f} s",
    }
|
|
|
|
|
if fusion_method != "dense_only": |
|
|
|
start_time = time.time() |
|
sparse_results = sparse_search(query, top_k * rescore_multiplier * 2) |
|
sparse_search_time = time.time() - start_time |
|
|
|
|
|
        # Fuse the dense and sparse candidate lists
        start_time = time.time()
        if fusion_method == "rrf":
            fusion_results = reciprocal_rank_fusion(dense_results, sparse_results, k=rrf_k)
            fusion_method_name = f"RRF (k={rrf_k})"
        elif fusion_method == "nsf":
            fusion_results = normalized_score_fusion(dense_results, sparse_results,
                                                     dense_weight=dense_weight, sparse_weight=sparse_weight)
            fusion_method_name = f"NSF z-score (dense={dense_weight:.1f}, sparse={sparse_weight:.1f})"
        else:
            fusion_results = dense_results
            fusion_method_name = "Dense only (fallback)"
        fusion_time = time.time() - start_time
|
|
|
|
|
        final_results = fusion_results[:top_k * rescore_multiplier]
        final_doc_ids = [doc_id for doc_id, _ in final_results]
        final_scores = [score for _, score in final_results]

        timing_info.update({
            "Time for the sparse search": f"{sparse_search_time:.4f} s",
            "Time for the fusion": f"{fusion_time:.4f} s",
        })
        timing_info["Fusion method used"] = fusion_method_name
    else:
        # Dense-only: skip the sparse stage and fusion entirely
        final_doc_ids = [doc_id for doc_id, _ in dense_results[:top_k * rescore_multiplier]]
        final_scores = [score for _, score in dense_results[:top_k * rescore_multiplier]]
        timing_info["Fusion method used"] = "Dense only"
|
|
|
|
|
    # Format the top-k results, aggregating paragraph scores into per-article scores
    start_time = time.time()
    try:
        top_k_titles, top_k_texts = zip(*[(wikipedia_dataset[int(idx)]["title"], wikipedia_dataset[int(idx)]["text"])
                                          for idx in final_doc_ids[:top_k]])

        df = pd.DataFrame({
            "Paragraph_score": [round(float(score), 4) for score in final_scores[:top_k]],
            "Title": top_k_titles,
            "Text": top_k_texts
        })

        # An article's score is the sum of the scores of its retrieved paragraphs
        score_sum = df.groupby('Title')['Paragraph_score'].sum().reset_index()
        df = pd.merge(df, score_sum, on='Title', how='left')
        df.rename(columns={'Paragraph_score_y': 'Article_score', 'Paragraph_score_x': 'Paragraph_score'}, inplace=True)
        df = df[["Article_score", "Paragraph_score", "Title", "Text"]]
        df = df.sort_values('Article_score', ascending=False)

    except Exception as e:
        print(f"Error creating results DataFrame: {e}")
        df = pd.DataFrame({"Error": [f"No results found or error processing results: {e}"]})

    sort_time = time.time() - start_time
|
total_time = time.time() - total_start_time |
|
|
|
    timing_info.update({
        "Time to format the results": f"{sort_time:.4f} s",
        "Total time": f"{total_time:.4f} s",
    })
|
|
|
return df, timing_info |
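
# Hedged usage sketch outside the UI (illustrative query; outputs depend on the corpus and indexes):
#   df, timings = search("Quelle est la capitale de la France ?", top_k=5,
#                        use_approx=True, fusion_method="rrf", rrf_k=60)
#   print(df[["Article_score", "Title"]].head())
#   print(timings["Total time"])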
|
|
|
|
|
with gr.Blocks(title="Query Wikipedia with hybrid fusion 🔍") as demo:
|
|
|
    gr.Markdown(
        """
        ## Query Wikipedia in real time 🔍
        This demo lets you query a corpus made up of the 250K most-viewed paragraphs of the French-language Wikipedia.
        Results are returned in real time by a pipeline running on a CPU 🚀
        It is heavily inspired by the [quantized-retrieval](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval) Space built by [Tom Aarsen](https://huggingface.co/tomaarsen) 🤗
        If you want to know more about the full process behind this demo, feel free to expand the sections below.

        <details><summary>1. About the data</summary>
        The corpus corresponds to the first 250,000 rows of the <a href="https://hf.co/datasets/Cohere/wikipedia-22-12-fr-embeddings"><i>wikipedia-22-12-fr-embeddings</i></a> dataset released by Cohere.
        As its name suggests, it is a snapshot from December 2022; keep that in mind when writing your queries.
        It is also only a subset of the full dataset, namely the 250,000 most-viewed paragraphs at that date.
        So a narrow query on a rarely visited topic will probably not return anything relevant.
        Note as well that Cohere preprocessed the data, which removed dates, for example.
        This dataset is therefore not optimal; the idea was to get something working quickly.
        As a next step, this demo will be extended to the full <i>wikipedia-22-12-fr-embeddings</i> dataset (13M paragraphs).
        Moving afterwards to a more recent snapshot of Wikipedia (for instance <a href="https://hf.co/datasets/wikimedia/wikipedia"><i>wikimedia/wikipedia</i></a>) is not ruled out either.
        </details>

        <details><summary>2. About the pipeline</summary>
        To be written once it is finalized.
        </details>
        """
    )
|
|
|
with gr.Row(): |
|
with gr.Column(scale=75): |
|
            query = gr.Textbox(
                label="Query the French-language Wikipedia",
                placeholder="Enter a query to search for relevant texts in Wikipedia.",
            )
        with gr.Column(scale=25):
            use_approx = gr.Radio(
                choices=[("Exact", False), ("Approximate", True)],
                value=True,
                label="Dense search type",
            )
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
            top_k = gr.Slider(
                minimum=3,
                maximum=40,
                step=1,
                value=15,
                label="Number of documents to return",
            )
        with gr.Column(scale=2):
            rescore_multiplier = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=1,
                label="Rescoring multiplier",
                info="Increases the number of candidates before fusion",
            )
|
|
|
with gr.Row(): |
|
with gr.Column(scale=3): |
|
            fusion_method = gr.Radio(
                choices=[
                    ("Dense only", "dense_only"),
                    ("RRF fusion", "rrf"),
                    ("NSF z-score fusion", "nsf")
                ],
                value="rrf",
                label="Fusion method",
                info="Choose how to combine the results of the dense and sparse models"
            )
        with gr.Column(scale=2):
            rrf_k = gr.Slider(
                minimum=1,
                maximum=200,
                step=1,
                value=60,
                label="k parameter for RRF",
                info="The higher k is, the less the ranks matter",
                visible=True
            )
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
            dense_weight = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.5,
                label="Dense weight (NSF)",
                info="Importance of the dense model's results",
                visible=False
            )
        with gr.Column(scale=2):
            sparse_weight = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.5,
                label="Sparse weight (NSF)",
                info="Importance of the sparse model's results",
                visible=False
            )
|
|
|
|
|
    # Toggle parameter visibility: show rrf_k only for RRF, the weight sliders only for NSF
    fusion_method.change(
|
fn=lambda method: ( |
|
gr.update(visible=(method == "rrf")), |
|
gr.update(visible=(method == "nsf")), |
|
gr.update(visible=(method == "nsf")) |
|
), |
|
inputs=[fusion_method], |
|
outputs=[rrf_k, dense_weight, sparse_weight] |
|
) |
|
|
|
    search_button = gr.Button(value="Search", variant="primary")
|
|
|
    output = gr.Dataframe(headers=["Article_score", "Paragraph_score", "Title", "Text"], datatype="markdown")
|
    perf_json = gr.JSON(label="Performance information")

    query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx, fusion_method, rrf_k, dense_weight, sparse_weight], outputs=[output, perf_json])
    search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx, fusion_method, rrf_k, dense_weight, sparse_weight], outputs=[output, perf_json])
|
|
|
demo.queue() |
|
demo.launch() |