"""Gradio demo running a safety classifier and a dating-sentiment classifier
over a six-message conversation that alternates between User A and User B."""

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    RobertaTokenizerFast,
    RobertaForSequenceClassification,
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Token budget passed to both tokenizers (longer inputs are truncated to this).
MAX_LENGTH = 512
# Maximum number of preceding messages prepended as context for the safety model.
CONTEXT_WINDOW = 5

# --- Safety Classification Model ---
safety_model = RobertaForSequenceClassification.from_pretrained("Alifjo123/roberta-safety-model")
safety_tokenizer = RobertaTokenizerFast.from_pretrained("Alifjo123/roberta-safety-model")
safety_model.to(device).eval()

# --- Dating Sentiment Classification Model ---
dating_model = AutoModelForSequenceClassification.from_pretrained("Alifjo123/dating-sentiment-model")
dating_tokenizer = AutoTokenizer.from_pretrained("Alifjo123/dating-sentiment-model")
dating_model.to(device).eval()
# NOTE(review): assumed to follow the model's id2label index order -- confirm against its config.
dating_labels = ["negative", "neutral", "positive"]


def _predict(model, tokenizer, text):
    """Classify one text and return ``(label_index, probability_of_that_label)``.

    The text is tokenized with truncation to MAX_LENGTH tokens, moved to ``device``,
    and run without gradient tracking; the top softmax class is returned.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LENGTH)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Batch of one: take the first (only) row of class probabilities.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    label = torch.argmax(probs).item()
    return label, probs[label].item()


def _build_safety_input(previous, msg):
    """Join context messages and ``msg`` with the safety tokenizer's sep token.

    ``previous`` is ordered oldest-first. Tokenizers truncate from the right by
    default, so an over-long context would cut off the message actually being
    classified. Instead, the oldest context messages are dropped until the joined
    text fits in MAX_LENGTH tokens. If no context fits, ``msg`` alone is returned
    (and is truncated normally if it is itself too long).
    """
    sep = safety_tokenizer.sep_token
    context = list(previous)
    while context:
        candidate = sep.join(context) + sep + msg
        if len(safety_tokenizer(candidate, truncation=False)["input_ids"]) <= MAX_LENGTH:
            return candidate
        context.pop(0)  # discard the oldest message and retry
    return msg


# --- Classifier function ---
def classify_text_all(msg1, msg2, msg3, msg4, msg5, msg6, use_context):
    """Classify six conversation messages for safety and dating sentiment.

    Args:
        msg1..msg6: Message texts in conversation order. ``None`` is treated as empty.
        use_context: If true, up to CONTEXT_WINDOW preceding non-empty messages are
            prepended (joined by the safety tokenizer's sep token) before safety
            classification. Sentiment is always computed on the message alone.

    Returns:
        A flat list of 18 strings: for each message in order, its display text,
        its safety result and its sentiment result. An empty message yields a
        warning text followed by two empty strings.
    """
    # Normalize once; the display text and the context both use stripped messages.
    messages = [(m or "").strip() for m in (msg1, msg2, msg3, msg4, msg5, msg6)]
    flat_results = []

    for idx, msg in enumerate(messages):
        if not msg:
            flat_results.extend((f"Message {idx+1}: ⚠️ Empty input.", "", ""))
            continue

        # === Build context from previous messages ===
        if use_context:
            previous = [m for m in messages[max(0, idx - CONTEXT_WINDOW):idx] if m]
            full_msg = _build_safety_input(previous, msg)
        else:
            full_msg = msg

        # === Safety prediction ===
        # Label index 1 is displayed as unsafe.
        # NOTE(review): confirm this matches the model's id2label.
        safety_label, safety_conf = _predict(safety_model, safety_tokenizer, full_msg)
        safety_result = f"{'🚫 Unsafe' if safety_label == 1 else '✅ Safe'} ({safety_conf:.2%})"

        # === Dating sentiment (single message only) ===
        dating_label, dating_conf = _predict(dating_model, dating_tokenizer, msg)
        dating_result = f"{dating_labels[dating_label].capitalize()} ({dating_conf:.2%})"

        flat_results.extend((f"Message {idx+1}: {msg}", safety_result, dating_result))

    return flat_results


# --- Gradio UI ---
# Speaker for each of the six message slots, alternating A/B.
SPEAKERS = ["User A", "User B"] * 3

demo = gr.Interface(
    fn=classify_text_all,
    inputs=[gr.Textbox(lines=2, label=speaker) for speaker in SPEAKERS]
    + [gr.Checkbox(label="Use context (previous messages)", value=True)],
    # Three output boxes per message, in the same order classify_text_all returns them.
    outputs=[
        component
        for speaker in SPEAKERS
        for component in (
            gr.Textbox(label=f"🗨️ {speaker}"),
            gr.Textbox(label=f"🔐 Safety {speaker}"),
            # NOTE(review): leading space looks like a dropped emoji; label kept as-is.
            gr.Textbox(label=f" Sentiment {speaker}"),
        )
    ],
    title="🧠 Multi-Model Analyzer (English Only)",
    description=(
        "This interface evaluates **six messages** using **two models**:\n\n"
        "1. **Safety Classifier** (RoBERTa) fine-tuned with additional datasets 22k+ with over 3000k+ conversations.\n\n"
        "2. **Sentiment Classifier** (RoBERTa) fine-tuned with additional datasets 102k+.\n\n"
        "📌 Both models require continuous training on diverse, high-quality datasets to better capture the nuances of conversations."
    ),
)

if __name__ == "__main__":
    demo.launch()