""" Hugging Face compatible model wrapper for your existing Space This works alongside your existing model loading without breaking it """ import torch import torch.nn as nn from typing import Optional import sys import os # Import transformers components if available try: from transformers import PreTrainedModel from transformers.modeling_outputs import ImageClassifierOutput TRANSFORMERS_AVAILABLE = True except ImportError: TRANSFORMERS_AVAILABLE = False # Fallback classes class PreTrainedModel(nn.Module): def __init__(self, config): super().__init__() self.config = config class ImageClassifierOutput: def __init__(self, loss=None, logits=None): self.loss = loss self.logits = logits # Import your existing components sys.path.append(os.path.join(os.path.dirname(__file__), "training")) try: from hf_config import HFResNet18DetectorConfig except ImportError: # Fallback config class HFResNet18DetectorConfig: def __init__(self, num_classes=2, **kwargs): self.num_classes = num_classes for key, value in kwargs.items(): setattr(self, key, value) class HFResNet18Detector(PreTrainedModel): """ Hugging Face compatible wrapper for your existing model This allows your model to work with HF Trainer and ecosystem """ config_class = HFResNet18DetectorConfig def __init__(self, config: HFResNet18DetectorConfig): super().__init__(config) self.num_labels = getattr(config, 'num_classes', 2) self.config = config # Try to use your existing model creation logic first try: from training.detection_models import create_model from training.config import get_model_config model_config = get_model_config("resnet18") self.backbone = create_model("resnet18", model_config) print("[HF Model] Using existing model creation logic") except Exception as e: print(f"[HF Model] Fallback to basic ResNet18: {e}") # Fallback to basic ResNet18 from torchvision.models import resnet18, ResNet18_Weights weights = ResNet18_Weights.IMAGENET1K_V1 self.backbone = resnet18(weights=weights) # Replace final layer with enhanced regularization in_features = self.backbone.fc.in_features dropout_rate = getattr(config, 'dropout_rate', 0.5) num_classes = getattr(config, 'num_classes', 2) # Multi-layer classification head with stronger regularization self.backbone.fc = nn.Sequential( nn.Dropout(dropout_rate), nn.Linear(in_features, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Dropout(0.6), # Higher dropout for intermediate layer nn.Linear(512, 256), nn.ReLU(), nn.BatchNorm1d(256), nn.Dropout(0.7), # Even higher dropout near output nn.Linear(256, num_classes) ) def forward( self, pixel_values: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, **kwargs ): """ Forward pass compatible with both HF and your existing code """ # Handle both HF format and your existing format if pixel_values is None: raise ValueError("pixel_values must be provided") # Forward pass through your existing model logits = self.backbone(pixel_values) loss = None if labels is not None: # Ensure labels are properly formatted if isinstance(labels, torch.Tensor): labels = labels.long() else: labels = torch.tensor(labels, dtype=torch.long) # Ensure labels are 1D if labels.dim() > 1: labels = labels.squeeze() # Use label smoothing to combat overfitting with proper error handling try: label_smoothing = getattr(self.config, 'label_smoothing', 0.1) loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing) loss = loss_fct(logits, labels) except Exception as e: print(f"[HF Model] Label smoothing failed ({e}), falling back to standard CrossEntropyLoss") # Fallback to standard cross entropy if label smoothing fails loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits, labels) if TRANSFORMERS_AVAILABLE and return_dict: return ImageClassifierOutput( loss=loss, logits=logits, ) else: # Fallback for non-HF usage if loss is not None: return loss, logits return logits def predict_compatibility(self, x): """ Compatibility method for your existing inference code """ return self.backbone(x) # Register for auto-loading if transformers is available if TRANSFORMERS_AVAILABLE: try: HFResNet18Detector.register_for_auto_class("AutoModelForImageClassification") except: pass def create_hf_compatible_model(existing_model_path=None): """ Helper function to create HF compatible model from existing weights """ config = HFResNet18DetectorConfig() model = HFResNet18Detector(config) if existing_model_path and os.path.exists(existing_model_path): try: # Load your existing model weights checkpoint = torch.load(existing_model_path, map_location="cpu", weights_only=False) if 'model_state_dict' in checkpoint: model.backbone.load_state_dict(checkpoint['model_state_dict']) else: model.backbone.load_state_dict(checkpoint) print(f"[HF Model] Loaded weights from {existing_model_path}") except Exception as e: print(f"[HF Model] Failed to load weights: {e}") return model