import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import gradio as gr

# --- Import your architecture ---
# Make sure this file is in the repo (e.g., models/deberta_lstm_classifier.py)
# and update the import path accordingly.
from model import DeBERTaLSTMClassifier  # <-- your class
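
# For reference, the rest of this script assumes an interface roughly like the
# sketch below. This is an assumption about your class, not its actual
# definition -- adapt it to your real code:
#
#   class DeBERTaLSTMClassifier(nn.Module):
#       def __init__(self, **model_args):             # e.g., lstm_hidden, num_labels
#           ...
#       def forward(self, input_ids, attention_mask):
#           ...                                       # returns logits [batch, num_labels]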

# --------- Config ----------
REPO_ID = "khoa-done/phishing-detector"   # HF repo that holds the checkpoint
CKPT_NAME = "deberta_lstm_checkpoint.pt"  # the .pt file name
MODEL_NAME = "microsoft/deberta-base"     # base tokenizer/backbone
LABELS = ["benign", "phishing"]           # adjust to your classes

# If your checkpoint contains hyperparams, you can fetch them like:
#   checkpoint.get("config") or checkpoint.get("model_args")
# and pass them into DeBERTaLSTMClassifier(**model_args).

# --------- Load model/tokenizer once (global) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True; a checkpoint
# dict holding pickled objects needs weights_only=False (trusted files only).
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

# If you saved hyperparams in the checkpoint, use them:
model_args = checkpoint.get("model_args", {})  # e.g., {"lstm_hidden": 256, "num_labels": 2, ...}
model = DeBERTaLSTMClassifier(**model_args)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device).eval()
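
# Optional smoke test: confirm the load worked before serving traffic.
n_params = sum(p.numel() for p in model.parameters())
print(f"Loaded DeBERTa+LSTM with {n_params / 1e6:.1f}M parameters on {device}")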

# --------- Inference function ----------
def predict_fn(text: str):
    if not text or not text.strip():
        # gr.Label expects {label: probability}; surface input problems via gr.Error
        # instead of returning an error dict the Label component can't render.
        raise gr.Error("Please enter a URL or text.")

    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,    # single example -> becomes [1, seq_len]
        max_length=256,  # adjust to match training
    )
    # DeBERTa typically doesn't use token_type_ids
    inputs.pop("token_type_ids", None)
    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs)  # your model.forward should accept (input_ids, attention_mask)
    probs = F.softmax(logits, dim=-1).squeeze(0).tolist()

    # Build a label -> probability mapping for the Gradio Label output.
    # If the LABELS length doesn't match the logits dim, fall back to generic names.
    if len(LABELS) == len(probs):
        return {LABELS[i]: float(probs[i]) for i in range(len(LABELS))}
    return {f"class_{i}": float(p) for i, p in enumerate(probs)}

# --------- Gradio UI ----------
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Textbox(label="URL or text", placeholder="e.g., http://suspicious-site.example"),
    outputs=gr.Label(label="Prediction"),
    title="Phishing Detector (DeBERTa + LSTM)",
    description="Enter a URL/text. The model outputs class probabilities.",
    examples=[
        ["http://rendmoiunserviceeee.com"],
        ["https://www.google.com"],
        ["https://mail-secure-login-verify.example/path?token=..."],
    ],
)

if __name__ == "__main__":
    demo.launch()