# Gradio demo: translate the input into English with NLLB-200, then run a
# safety classifier and a dating-sentiment classifier on the translation.
import gradio as gr
import torch
import matplotlib.pyplot as plt
from transformers import (
    RobertaTokenizerFast, RobertaForSequenceClassification,
    AutoTokenizer, AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The NLLB-200 checkpoint pairs the M2M100 seq2seq architecture with its own
# NLLB tokenizer, so both are loaded through the Auto classes.
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-3.3B")
translator_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B").to(device).eval()
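# Note (suggestion): the 3.3B checkpoint needs substantial memory; a smaller
# NLLB checkpoint such as facebook/nllb-200-distilled-600M can be substituted here.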

# Safety classifier (RoBERTa): label 0 = safe, label 1 = unsafe.
safety_model = RobertaForSequenceClassification.from_pretrained("Alifjo123/roberta-safety-model")
safety_tokenizer = RobertaTokenizerFast.from_pretrained("Alifjo123/roberta-safety-model")
safety_model.to(device).eval()

# Dating-sentiment classifier with three labels.
dating_model = AutoModelForSequenceClassification.from_pretrained("Alifjo123/dating-sentiment-model")
dating_tokenizer = AutoTokenizer.from_pretrained("Alifjo123/dating-sentiment-model")
dating_model.to(device).eval()
dating_labels = ["negative", "neutral", "positive"]


def translate_to_english(text, src_lang="ind_Latn"):
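    """Translate text from src_lang (default "ind_Latn", Indonesian) into English with NLLB-200."""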
    translator_tokenizer.src_lang = src_lang
    encoded = translator_tokenizer(text, return_tensors="pt", padding=True).to(device)
    # Force the decoder to start generation with the English language token.
    generated = translator_model.generate(**encoded, forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids("eng_Latn"))
    translated = translator_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return translated[0]


def classify_text_all(input_text):
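    """Translate the input to English, run both classifiers, and return the results plus a confidence plot."""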
    if not input_text.strip():
        return "⚠️ Please enter a message.", "", plt.figure()

    translated_text = translate_to_english(input_text)
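
    # Safety classification (Safe vs Unsafe) on the translated text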
    safety_inputs = safety_tokenizer(translated_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    safety_inputs = {k: v.to(device) for k, v in safety_inputs.items()}
    with torch.no_grad():
        safety_outputs = safety_model(**safety_inputs)
    safety_probs = torch.nn.functional.softmax(safety_outputs.logits, dim=1)
    safety_label = torch.argmax(safety_probs).item()
    safety_conf = safety_probs[0][safety_label].item()
    safety_result = f"{'🚫 Unsafe' if safety_label == 1 else '✅ Safe'} ({safety_conf:.2%} confidence)"
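
    # Dating-sentiment classification on the same translated text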
    dating_inputs = dating_tokenizer(translated_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    dating_inputs = {k: v.to(device) for k, v in dating_inputs.items()}
    with torch.no_grad():
        dating_outputs = dating_model(**dating_inputs)
    dating_probs = torch.nn.functional.softmax(dating_outputs.logits, dim=1)
    dating_label = torch.argmax(dating_probs).item()
    dating_conf = dating_probs[0][dating_label].item()
    dating_result = f"{dating_labels[dating_label].capitalize()} ({dating_conf:.2%} confidence)"
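
    # Side-by-side bar charts of each model's class probabilities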
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    axs[0].bar(["Safe", "Unsafe"], safety_probs[0].cpu().numpy(), color=["green", "red"])
    axs[0].set_ylim([0, 1])
    axs[0].set_title("Safety Confidence")
    axs[0].set_ylabel("Probability")

    axs[1].bar(dating_labels, dating_probs[0].cpu().numpy(), color=["red", "gray", "green"])
    axs[1].set_ylim([0, 1])
    axs[1].set_title("Dating Sentiment")
    plt.tight_layout()

    return safety_result, dating_result, fig


demo = gr.Interface(
    fn=classify_text_all,
    inputs=gr.Textbox(lines=5, label="💬 Input a message (any language)"),
    outputs=[
        gr.Textbox(label="Safety Classification (Safe vs Unsafe)"),
        gr.Textbox(label="Dating Sentiment (Negative, Neutral, Positive)"),
        gr.Plot(label="Confidence Comparison")
    ],
    title="🧠 Multi-Model Analyzer (Multilingual Support)",
    description=(
        "This interface evaluates a single input using **two models**:\n\n"
        "1. **Safety Classifier** (`RoBERTa`) trained on contextual safety data.\n"
        "2. **Dating Sentiment Classifier** (`RoBERTa`) fine-tuned for romantic tone detection.\n\n"
        "Supports **multilingual inputs** using [facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B).\n"
        "Your message is automatically translated into English before classification."
    )
)
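
# Tip: demo.launch(share=True) would also expose a temporary public URL.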

if __name__ == "__main__":
    demo.launch()