import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import gradio as gr

# --- Import your architecture ---
# Make sure the module that defines the class is in the repo
# (e.g., model.py or models/deberta_lstm_classifier.py) and that the
# import path below matches its location.
from model import DeBERTaLSTMClassifier  # <-- your class

# --------- Config ----------
REPO_ID = "khoa-done/phishing-detector"   # HF repo that holds the checkpoint
CKPT_NAME = "deberta_lstm_checkpoint.pt"  # the .pt file name
MODEL_NAME = "microsoft/deberta-base"     # base tokenizer/backbone
LABELS = ["benign", "phishing"]           # must match the label order used in training

# If your checkpoint stores hyperparameters, you can fetch them with
# checkpoint.get("config") or checkpoint.get("model_args") and pass them
# into DeBERTaLSTMClassifier(**model_args), as done below.

# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# Note: on recent PyTorch versions you may need weights_only=False here
# if the checkpoint stores non-tensor Python objects.
checkpoint = torch.load(ckpt_path, map_location=device)

# If you saved hyperparameters in the checkpoint, use them:
model_args = checkpoint.get("model_args", {})  # e.g., {"lstm_hidden": 256, "num_labels": 2, ...}

model = DeBERTaLSTMClassifier(**model_args)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device).eval()

# --------- Inference function ----------
def predict_fn(text: str):
    if not text or not text.strip():
        # gr.Label expects a str or a {label: prob} dict, so surface empty
        # input as a UI error instead of returning a message dict.
        raise gr.Error("Please enter a URL or text.")

    # Tokenize (a single example -> shape [1, seq_len])
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,  # adjust to the max length used during training
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)

    # Move tensors to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # your model.forward should accept (input_ids, attention_mask)
        # and return raw logits of shape [1, num_labels]
        logits = model(**inputs)

    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()

    # Build a label -> probability mapping for the Gradio Label output.
    # If LABELS doesn't match the logits dimension, fall back to generic names.
    if len(LABELS) == len(probs):
        return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    else:
        return {f"class_{i}": float(p) for i, p in enumerate(probs)}

# --------- Gradio UI ----------
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Textbox(label="URL or text", placeholder="e.g., http://suspicious-site.example"),
    outputs=gr.Label(label="Prediction"),
    title="Phishing Detector (DeBERTa + LSTM)",
    description="Enter a URL/text. The model outputs class probabilities.",
    examples=[
        ["http://rendmoiunserviceeee.com"],
        ["https://www.google.com"],
        ["https://mail-secure-login-verify.example/path?token=..."],
    ],
)

if __name__ == "__main__":
    demo.launch()
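
# --------- Reference: a possible model.py (sketch) ----------
# The script above only assumes two things about DeBERTaLSTMClassifier:
# its constructor accepts the kwargs stored in model_args, and
# forward(input_ids, attention_mask) returns raw logits. The class below
# is a minimal sketch under those assumptions (DeBERTa backbone -> BiLSTM
# -> linear head), NOT necessarily the actual implementation:
#
#   import torch
#   import torch.nn as nn
#   from transformers import AutoModel
#
#   class DeBERTaLSTMClassifier(nn.Module):
#       def __init__(self, model_name="microsoft/deberta-base",
#                    lstm_hidden=256, num_labels=2):
#           super().__init__()
#           self.backbone = AutoModel.from_pretrained(model_name)
#           self.lstm = nn.LSTM(self.backbone.config.hidden_size, lstm_hidden,
#                               batch_first=True, bidirectional=True)
#           self.classifier = nn.Linear(2 * lstm_hidden, num_labels)
#
#       def forward(self, input_ids, attention_mask):
#           # [batch, seq_len, hidden] token states from DeBERTa
#           hidden = self.backbone(input_ids=input_ids,
#                                  attention_mask=attention_mask).last_hidden_state
#           _, (h_n, _) = self.lstm(hidden)
#           # concatenate the final forward and backward LSTM states
#           pooled = torch.cat([h_n[-2], h_n[-1]], dim=-1)
#           return self.classifier(pooled)  # raw logits; caller applies softmax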
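
# --------- Reference: saving a compatible checkpoint (sketch) ----------
# The loading code expects a dict with a "model_state_dict" key and an
# optional "model_args" key. On the training side, a compatible checkpoint
# could be produced like this (the key names are the only contract the app
# relies on; the model_args values are illustrative):
#
#   torch.save(
#       {
#           "model_state_dict": model.state_dict(),
#           "model_args": {"lstm_hidden": 256, "num_labels": 2},
#       },
#       "deberta_lstm_checkpoint.pt",
#   )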
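
# --------- Optional: quick local sanity check ----------
# Before deploying, you can call the inference function directly in a
# REPL; the probabilities shown below are illustrative only:
#
#   >>> predict_fn("http://suspicious-site.example")
#   {'benign': 0.12, 'phishing': 0.88}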