"""Gradio demo: classify a skin image into DermaMNIST classes."""

import gradio as gr
import torch
import torchvision.transforms as transforms
from medmnist import INFO
from model import load_model
from PIL import Image

# Class names for the DermaMNIST task. The list order must match the model's
# output columns; this relies on INFO["dermamnist"]["label"] being stored in
# class-index order — NOTE(review): confirm against the medmnist INFO table.
info = INFO["dermamnist"]
class_names = list(info["label"].values())

# Load model and force inference mode (disables dropout, freezes batch-norm
# statistics). Harmless if load_model() already called eval().
model = load_model()
model.eval()

# Feed inputs to whichever device the model's weights live on; fall back to
# CPU when the model exposes no parameters.
_first_param = next(model.parameters(), None)
device = _first_param.device if _first_param is not None else torch.device("cpu")

# Transforms (match training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def predict(image):
    """Return per-class probabilities for a single uploaded image.

    Args:
        image: A ``PIL.Image.Image`` supplied by Gradio, or ``None`` when
            nothing was uploaded.

    Returns:
        A dict mapping each class name to its softmax probability, suitable
        for ``gr.Label``. If ``image`` is ``None``, returns ``{"Error": 1.0}``.
    """
    if image is None:
        # Kept for backward compatibility: gr.Label shows this as one label.
        return {"Error": 1.0}

    # Normalize channel layout (e.g. RGBA / grayscale -> 3-channel RGB) so it
    # matches the 3-channel Normalize above.
    image = image.convert("RGB")
    # Add a batch dimension of 1 and move to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)

    # Softmax over the class dimension, take the single batch row, and bring
    # it back to CPU before converting (numpy()/tolist() cannot read CUDA
    # memory directly).
    probs = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return {name: float(p) for name, p in zip(class_names, probs)}


# Gradio UI
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Skin Disease Classifier",
    description=(
        "Upload a skin image and the model will predict potential skin cancer "
        "(melanoma), tumor or moles using EfficientNet-B2 fine-tuned on DermaMNIST."
    ),
)

if __name__ == "__main__":
    demo.launch()