import torch
import torchvision
from torch import nn
from torchvision.models import EfficientNet_B2_Weights


def create_model(num_classes=7, *, weights=EfficientNet_B2_Weights.DEFAULT):
    """Build an EfficientNet-B2 with a fresh classification head.

    All parameters of the base network are frozen (``requires_grad=False``).
    The classifier is then replaced with ``Dropout(0.3) -> Linear(in_features,
    num_classes)``, whose parameters are created after the freeze and
    therefore remain trainable. This is the usual head-only
    transfer-learning setup.

    Args:
        num_classes: Number of output logits of the new classifier head.
        weights: Pretrained weights passed to
            ``torchvision.models.efficientnet_b2``. Defaults to the ImageNet
            ``EfficientNet_B2_Weights.DEFAULT`` (downloaded on first use).
            Pass ``None`` to skip the download when every parameter will be
            overwritten from a checkpoint anyway.

    Returns:
        The ``torchvision`` EfficientNet-B2 module with the replaced head.
    """
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze the pretrained backbone so that only the new head trains.
    # This affects gradient tracking only; it has no effect on inference.
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the existing final Linear layer
    # rather than hard-coding it.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_features, num_classes),
    )
    return model


def load_model(weights_path="model/effnetb2_dermamnist.pth", num_classes=7):
    """Load a fine-tuned EfficientNet-B2 checkpoint for CPU inference.

    Args:
        weights_path: Path to a saved ``state_dict`` for the model produced
            by ``create_model``.
        num_classes: Output size of the classifier head. It must match the
            checkpoint.

    Returns:
        The model with the checkpoint loaded, mapped to CPU and switched to
        ``eval()`` mode. This disables dropout and uses stored BatchNorm
        statistics.
    """
    # The strict load_state_dict below overwrites every parameter and buffer,
    # so downloading the ImageNet weights first would be wasted work (and
    # would fail offline).
    model = create_model(num_classes=num_classes, weights=None)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model