import gradio as gr
import torch
import matplotlib.pyplot as plt
from transformers import (
    RobertaTokenizerFast, RobertaForSequenceClassification,
    AutoTokenizer, AutoModelForSequenceClassification,
    M2M100ForConditionalGeneration
)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Translation Model: NLLB-200 (200+ languages to English) ---
# NLLB checkpoints ship an NLLB tokenizer, so load it via AutoTokenizer rather
# than M2M100Tokenizer; the model itself uses the M2M100 seq2seq architecture.
translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-3.3B")
translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/nllb-200-3.3B").to(device).eval()
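# NOTE (suggestion only, not part of the committed Space): nllb-200-3.3B is a
# large checkpoint; on smaller hardware the distilled
# facebook/nllb-200-distilled-600M checkpoint is a common drop-in replacement.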
# --- Model 1: Safety Classification ---
safety_model = RobertaForSequenceClassification.from_pretrained("Alifjo123/roberta-safety-model")
safety_tokenizer = RobertaTokenizerFast.from_pretrained("Alifjo123/roberta-safety-model")
safety_model.to(device).eval()
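# NOTE: the result string and bar chart below assume this checkpoint maps
# label id 1 -> Unsafe and id 0 -> Safe.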
# --- Model 2: Dating Sentiment (Custom RoBERTa) ---
dating_model = AutoModelForSequenceClassification.from_pretrained("Alifjo123/dating-sentiment-model")
dating_tokenizer = AutoTokenizer.from_pretrained("Alifjo123/dating-sentiment-model")
dating_model.to(device).eval()
dating_labels = ["negative", "neutral", "positive"]
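# NOTE: this order assumes the fine-tuned checkpoint maps id 0 -> negative,
# 1 -> neutral, 2 -> positive; if its config defines id2label, reading
# dating_model.config.id2label directly avoids a possible mismatch.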
# --- Translation Function ---
def translate_to_english(text, src_lang="ind_Latn"):
    # Source language defaults to Indonesian ("ind_Latn" in NLLB's FLORES-200 codes).
    translator_tokenizer.src_lang = src_lang
    encoded = translator_tokenizer(text, return_tensors="pt", padding=True).to(device)
    # Force English ("eng_Latn") as the target language.
    generated = translator_model.generate(**encoded, forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids("eng_Latn"))
    translated = translator_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return translated[0]
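# Illustrative usage (not executed here) with other FLORES-200 source codes:
#   translate_to_english("Bonjour tout le monde", src_lang="fra_Latn")
#   translate_to_english("Hola, ¿cómo estás?", src_lang="spa_Latn")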
def classify_text_all(input_text):
    if not input_text.strip():
        return "⚠️ Please enter a message.", "", plt.figure()

    # --- Step 1: Translate to English for classification ---
    translated_text = translate_to_english(input_text)

    # --- Safety Classification ---
    safety_inputs = safety_tokenizer(translated_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    safety_inputs = {k: v.to(device) for k, v in safety_inputs.items()}
    with torch.no_grad():
        safety_outputs = safety_model(**safety_inputs)
    safety_probs = torch.nn.functional.softmax(safety_outputs.logits, dim=1)
    safety_label = torch.argmax(safety_probs).item()
    safety_conf = safety_probs[0][safety_label].item()
    safety_result = f"{'🚫 Unsafe' if safety_label == 1 else '✅ Safe'} ({safety_conf:.2%} confidence)"

    # --- Dating Sentiment Classification ---
    dating_inputs = dating_tokenizer(translated_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    dating_inputs = {k: v.to(device) for k, v in dating_inputs.items()}
    with torch.no_grad():
        dating_outputs = dating_model(**dating_inputs)
    dating_probs = torch.nn.functional.softmax(dating_outputs.logits, dim=1)
    dating_label = torch.argmax(dating_probs).item()
    dating_conf = dating_probs[0][dating_label].item()
    dating_result = f"{dating_labels[dating_label].capitalize()} ({dating_conf:.2%} confidence)"

    # --- Confidence Bar Charts ---
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    axs[0].bar(["Safe", "Unsafe"], safety_probs[0].cpu().numpy(), color=["green", "red"])
    axs[0].set_ylim([0, 1])
    axs[0].set_title("Safety Confidence")
    axs[0].set_ylabel("Probability")
    axs[1].bar(dating_labels, dating_probs[0].cpu().numpy(), color=["red", "gray", "green"])
    axs[1].set_ylim([0, 1])
    axs[1].set_title("Dating Sentiment")
    plt.tight_layout()

    return safety_result, dating_result, fig
# Gradio UI
demo = gr.Interface(
    fn=classify_text_all,
    inputs=gr.Textbox(lines=5, label="💬 Input a message (any language)"),
    outputs=[
        gr.Textbox(label="🔍 Safety Classification (Safe vs Unsafe)"),
        gr.Textbox(label="💘 Dating Sentiment (Negative, Neutral, Positive)"),
        gr.Plot(label="📊 Confidence Comparison")
    ],
    title="🧠 Multi-Model Analyzer (Multilingual Support)",
    description=(
        "This interface evaluates a single input using **two models**:\n\n"
        "1. **Safety Classifier** (`RoBERTa`) trained on contextual safety data.\n"
        "2. **Dating Sentiment Classifier** (`RoBERTa`) fine-tuned for romantic tone detection.\n\n"
        "🈶 Supports **multilingual inputs** using [facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B).\n"
        "Your message is automatically translated into English before classification."
    )
)
if __name__ == "__main__":
    demo.launch()
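# Illustrative only: once the Space is running, it can also be queried from
# Python with gradio_client; the Space id below is a placeholder and
# "/predict" assumes the default gr.Interface endpoint name.
#
#   from gradio_client import Client
#   client = Client("username/space-name")  # hypothetical Space id
#   safety, sentiment, _ = client.predict("Kamu sangat baik", api_name="/predict")
#   print(safety, sentiment)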