|
import torch |
|
import torchvision |
|
from torchvision.models import EfficientNet_B2_Weights |
|
from torch import nn |
|
|
|
def create_model(
    num_classes=7,
    *,
    weights=EfficientNet_B2_Weights.DEFAULT,
    freeze_backbone=True,
    dropout=0.3,
):
    """Build a torchvision EfficientNet-B2 with a new classification head.

    The EfficientNet-B2 architecture is instantiated with ``weights``
    (torchvision's ``EfficientNet_B2_Weights.DEFAULT`` unless overridden).
    If requested, all of its existing parameters are frozen. Its
    ``classifier`` is then replaced with ``Dropout -> Linear`` producing
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the new linear head.
        weights: Value forwarded to ``torchvision.models.efficientnet_b2``.
            Pass ``None`` for random initialisation, which skips the
            weight download. That is useful when a full checkpoint is
            loaded right afterwards.
        freeze_backbone: If True, set ``requires_grad = False`` on every
            parameter that exists before the head is replaced, so only the
            new head is trainable.
        dropout: Dropout probability applied before the final linear layer.
            An out-of-range value is rejected by ``nn.Dropout``.

    Returns:
        The EfficientNet-B2 model with its classifier replaced.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = torchvision.models.efficientnet_b2(weights=weights)

    if freeze_backbone:
        # Freeze *before* swapping the head. The modules created below are
        # new, so they keep requires_grad=True and remain trainable.
        for param in model.parameters():
            param.requires_grad = False

    # The existing head's Linear layer sits at index 1; reuse its input width.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features, num_classes),
    )

    return model
|
|
|
def load_model(
    weights_path="model/effnetb2_dermamnist.pth",
    *,
    num_classes=7,
    device="cpu",
):
    """Build the model via ``create_model`` and restore saved weights for inference.

    Args:
        weights_path: Path to a checkpoint containing a state dict for the
            model built by ``create_model``. Presumably it was saved with
            ``torch.save(model.state_dict(), ...)``; confirm against the
            training script.
        num_classes: Output size of the classifier head. It must match the
            checkpoint, or ``load_state_dict`` raises. The default of 7 is
            the value that was previously hard-coded.
        device: Device (str or ``torch.device``) the checkpoint tensors are
            mapped to and the model is moved to. Defaults to CPU, as before.

    Returns:
        The model with the checkpoint loaded, on ``device``, in eval mode.
    """
    model = create_model(num_classes=num_classes)
    # NOTE(review): create_model fetches pretrained backbone weights here,
    # but the strict load_state_dict below overwrites every parameter and
    # buffer. Building the model without pretrained weights would avoid a
    # needless download and network dependency.

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code (needs torch >= 1.13;
    # it is already the default from torch 2.6).
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    # Switch dropout/batch-norm to inference behaviour.
    model.eval()
    return model
|
|