"""Gradio demo that classifies an uploaded image with a pretrained DeiT model."""

from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load the model + processor once at import time so every request reuses them.
model_id = "facebook/deit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_id)
processor = AutoImageProcessor.from_pretrained(model_id)


def classify_image(image: Optional[Image.Image]) -> str:
    """Return the model's top-1 label for *image*.

    Args:
        image: The PIL image to classify, or ``None`` when no image was
            supplied (e.g. the form was submitted empty).

    Returns:
        The ``id2label`` entry for the highest-scoring class, or a short
        message asking for an image when *image* is ``None``.
    """
    if image is None:
        # Nothing to classify; report it instead of letting the processor
        # raise on a missing input.
        return "No image provided."

    # The checkpoint is a 3-channel model; normalise grayscale / RGBA /
    # palette images so direct callers need not pre-convert.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Create Gradio interface
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Animal Species Classifier",
    # Derived from model_id so the text cannot drift from the loaded checkpoint.
    description=f"Using {model_id}",
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not when imported.
    demo.launch()