bobbysam commited on
Commit
58e5879
Β·
verified Β·
1 Parent(s): 959b26a

Model save

Browse files
Files changed (4) hide show
  1. README.md +78 -244
  2. hf_config.py +64 -0
  3. hf_model.py +179 -0
  4. model.safetensors +1 -1
README.md CHANGED
@@ -1,255 +1,89 @@
1
  ---
2
- license: apache-2.0
3
- language:
4
- - en
5
- base_model:
6
- - bobbysam/resnet18-image-detector
7
  library_name: transformers
8
- pipeline_tag: image-classification
9
  tags:
10
- - computer-vision
11
- - image-classification
12
- - ai-detection
13
- - pytorch
14
- - resnet
15
- datasets:
16
- - custom
17
  metrics:
18
  - accuracy
 
19
  - precision
20
  - recall
21
- - f1
22
  model-index:
23
  - name: resnet18-image-detector
24
- results:
25
- - task:
26
- type: image-classification
27
- name: AI vs Real Image Detection
28
- dataset:
29
- name: Custom AI Detection Dataset
30
- type: custom
31
- metrics:
32
- - type: accuracy
33
- value: 0.95
34
- name: Accuracy
35
- - type: f1
36
- value: 0.94
37
- name: F1 Score
38
- - type: precision
39
- value: 0.93
40
- name: Precision
41
- - type: recall
42
- value: 0.96
43
- name: Recall
44
- ---
45
- # ResNet18 AI Image Detector
46
-
47
- **Repository:** [bobbysam/resnet18-image-detector](https://huggingface.co/bobbysam/resnet18-image-detector)
48
-
49
- [![Train](https://huggingface.co/datasets/huggingface/badges/raw/main/train-on-spaces-sm.svg)](https://huggingface.co/spaces/autotrain-projects/train-resnet18-detector)
50
- [![Deploy](https://huggingface.co/datasets/huggingface/badges/raw/main/deploy-on-spaces-sm.svg)](https://huggingface.co/spaces/autotrain-projects/deploy-resnet18-detector)
51
-
52
- ---
53
-
54
- ## 🧠 What does this model do?
55
-
56
- This is a **ResNet18-based deep neural network** trained to **detect whether an input image is a real photograph or AI-generated** (binary classification: `real` vs. `ai_generated`).
57
- It is part of the [ProofGuard](https://github.com/Proofguard/proofguard-backend) project and can be used to build trustworthy AI image detection pipelines.
58
-
59
- **Key Features:**
60
- - πŸ”¬ Binary classification: Real vs AI-generated images
61
- - πŸš€ Fast inference with ResNet18 architecture
62
- - πŸ€— Compatible with Hugging Face Transformers
63
- - πŸ“Š Comprehensive evaluation metrics
64
- - 🎯 Easy-to-use inference API
65
-
66
- ---
67
-
68
- ## πŸš€ Quick Start
69
-
70
- ### **Option 1: Using Hugging Face Transformers (Recommended)**
71
-
72
- ```python
73
- from transformers import AutoModelForImageClassification, AutoImageProcessor
74
- from PIL import Image
75
- import torch
76
-
77
- # Load model and processor
78
- model = AutoModelForImageClassification.from_pretrained("bobbysam/resnet18-image-detector")
79
- processor = AutoImageProcessor.from_pretrained("bobbysam/resnet18-image-detector")
80
-
81
- # Load and process image
82
- image = Image.open("your_image.jpg")
83
- inputs = processor(image, return_tensors="pt")
84
-
85
- # Make prediction
86
- with torch.no_grad():
87
- outputs = model(**inputs)
88
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
89
- prediction = torch.argmax(probabilities, dim=-1).item()
90
-
91
- labels = ["Real", "AI-generated"]
92
- confidence = probabilities[0, prediction].item()
93
- print(f"Prediction: {labels[prediction]} (Confidence: {confidence:.2%})")
94
- ```
95
-
96
- ### **Option 2: Using the Inference Script**
97
-
98
- ```bash
99
- # Clone the repository
100
- git clone https://huggingface.co/bobbysam/resnet18-image-detector
101
- cd resnet18-image-detector
102
-
103
- # Install dependencies
104
- pip install -r requirements.txt
105
-
106
- # Run inference
107
- python inference.py --image path/to/your/image.jpg --model ./
108
- ```
109
-
110
- ### **Option 3: Using the Custom Wrapper**
111
-
112
- ```python
113
- from inference import AIImageDetector
114
-
115
- # Initialize detector
116
- detector = AIImageDetector()
117
-
118
- # Make prediction
119
- result = detector.predict("your_image.jpg")
120
- print(f"Prediction: {result['prediction']}")
121
- print(f"Confidence: {result['confidence']:.2%}")
122
- ``` ---
123
-
124
- ## πŸ‹οΈ Training Your Own Model
125
-
126
- ### **Quick Training with Hugging Face Trainer**
127
-
128
- ```bash
129
- # 1. Setup the environment
130
- python setup.py
131
-
132
- # 2. Download/prepare your dataset
133
- python download_dataset.py --dataset_type custom --source_dir /path/to/your/data
134
-
135
- # 3. Train the model
136
- python trainer.py \
137
- --data_dir ./data \
138
- --output_dir ./results \
139
- --num_epochs 10 \
140
- --batch_size 16 \
141
- --push_to_hub \
142
- --hub_model_id your-username/resnet18-detector
143
- ```
144
-
145
- ### **Training Arguments**
146
-
147
- | Argument | Description | Default |
148
- |----------|-------------|---------|
149
- | `--data_dir` | Path to dataset directory | Required |
150
- | `--output_dir` | Output directory for model | `./results` |
151
- | `--num_epochs` | Number of training epochs | 10 |
152
- | `--batch_size` | Training batch size | 16 |
153
- | `--learning_rate` | Learning rate | 2e-5 |
154
- | `--dropout_rate` | Dropout rate for regularization | 0.5 |
155
- | `--freeze_backbone` | Freeze ResNet backbone | False |
156
- | `--push_to_hub` | Push model to HF Hub | False |
157
- | `--hub_model_id` | Hugging Face model ID | None |
158
-
159
- ### **Dataset Structure**
160
-
161
- Your dataset should be organized as follows:
162
- ```
163
- data/
164
- β”œβ”€β”€ real/
165
- β”‚ β”œβ”€β”€ image1.jpg
166
- β”‚ β”œβ”€β”€ image2.jpg
167
- β”‚ └── ...
168
- └── ai_generated/
169
- β”œβ”€β”€ image1.jpg
170
- β”œβ”€β”€ image2.jpg
171
- └── ...
172
- ```
173
-
174
- ---
175
-
176
- ## πŸš€ Deployment Options
177
-
178
- This model supports multiple deployment options through Hugging Face:
179
-
180
- ### **1. Hugging Face Inference Endpoints**
181
- - Production-ready inference API
182
- - Auto-scaling and load balancing
183
- - Pay-per-request pricing
184
-
185
- ### **2. Amazon SageMaker**
186
- - Deploy directly to AWS SageMaker
187
- - Enterprise-grade infrastructure
188
- - Custom scaling policies
189
-
190
- ### **3. Azure ML**
191
- - Deploy to Azure Machine Learning
192
- - Integration with Azure services
193
- - Enterprise security features
194
-
195
- ### **4. Local Deployment**
196
- ```python
197
- # Load model locally
198
- from transformers import pipeline
199
-
200
- classifier = pipeline(
201
- "image-classification",
202
- model="bobbysam/resnet18-image-detector",
203
- device=0 if torch.cuda.is_available() else -1
204
- )
205
-
206
- result = classifier("path/to/image.jpg")
207
- ```
208
-
209
- ---
210
-
211
- ## πŸ“₯ Input format and requirements
212
-
213
- - **Input:** RGB image (PIL Image or file path), resized to 224x224, normalized as in ImageNet.
214
- - **Output:**
215
- - `0` = Real photograph
216
- - `1` = AI-generated image
217
-
218
- ---
219
-
220
- ## πŸ“¦ Model details
221
-
222
- - **Architecture:** ResNet18 (PyTorch, torchvision)
223
- - **Training data:** Real & AI-generated images (see [ProofGuard project](https://github.com/Proofguard/proofguard-backend))
224
- - **Framework:** PyTorch
225
- - **Size:** ~60MB
226
-
227
- ---
228
-
229
- ## βš–οΈ License and usage
230
-
231
- - **License:** [MIT](https://opensource.org/license/mit/) (or specify your own)
232
- - **Usage restrictions:** For research, education, and non-commercial projects.
233
- _For commercial use, contact the author or check the ProofGuard project license._
234
-
235
- ---
236
-
237
- ## πŸ™ Citation
238
-
239
- If you use this model, please cite:
240
- ```text
241
- ProofGuard: AI Image Authenticity Detection
242
- https://github.com/Proofguard/proofguard-backend
243
- Model by @bobbysam (Hugging Face)
244
- ```
245
-
246
- ---
247
-
248
- ## πŸ› οΈ Maintainer
249
-
250
- - [@bobbysam](https://huggingface.co/bobbysam)
251
- - [ProofGuard GitHub](https://github.com/Proofguard/proofguard-backend)
252
-
253
  ---
254
 
255
- *Feel free to open issues or PRs on the [ProofGuard repo](https://github.com/Proofguard/proofguard-backend) for improvements or questions!*
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
 
 
 
 
 
2
  library_name: transformers
 
3
  tags:
4
+ - generated_from_trainer
 
 
 
 
 
 
5
  metrics:
6
  - accuracy
7
+ - f1
8
  - precision
9
  - recall
 
10
  model-index:
11
  - name: resnet18-image-detector
12
+ results: []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  ---
14
 
15
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
16
+ should probably proofread and complete it, then remove this comment. -->
17
+
18
+ # resnet18-image-detector
19
+
20
+ This model is a fine-tuned version of [](https://huggingface.co/) on the None dataset.
21
+ It achieves the following results on the evaluation set:
22
+ - Loss: 0.2461
23
+ - Accuracy: 0.9738
24
+ - F1: 0.9737
25
+ - Precision: 0.9739
26
+ - Recall: 0.9738
27
+
28
+ ## Model description
29
+
30
+ More information needed
31
+
32
+ ## Intended uses & limitations
33
+
34
+ More information needed
35
+
36
+ ## Training and evaluation data
37
+
38
+ More information needed
39
+
40
+ ## Training procedure
41
+
42
+ ### Training hyperparameters
43
+
44
+ The following hyperparameters were used during training:
45
+ - learning_rate: 0.0001
46
+ - train_batch_size: 16
47
+ - eval_batch_size: 16
48
+ - seed: 42
49
+ - gradient_accumulation_steps: 2
50
+ - total_train_batch_size: 32
51
+ - optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
52
+ - lr_scheduler_type: cosine_with_restarts
53
+ - lr_scheduler_warmup_ratio: 0.1
54
+ - num_epochs: 3
55
+
56
+ ### Training results
57
+
58
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | F1 | Precision | Recall |
59
+ |:-------------:|:------:|:----:|:---------------:|:--------:|:------:|:---------:|:------:|
60
+ | 1.3887 | 0.0533 | 50 | 0.6371 | 0.7338 | 0.7336 | 0.7345 | 0.7338 |
61
+ | 1.1433 | 0.1067 | 100 | 0.4604 | 0.8571 | 0.8569 | 0.8591 | 0.8571 |
62
+ | 0.8563 | 0.16 | 150 | 0.3538 | 0.9081 | 0.9080 | 0.9094 | 0.9081 |
63
+ | 0.7671 | 0.2133 | 200 | 0.3244 | 0.9277 | 0.9277 | 0.9282 | 0.9277 |
64
+ | 0.7213 | 0.2667 | 250 | 0.3244 | 0.9301 | 0.9300 | 0.9307 | 0.9301 |
65
+ | 0.6996 | 0.32 | 300 | 0.3187 | 0.9324 | 0.9323 | 0.9339 | 0.9324 |
66
+ | 0.6975 | 0.3733 | 350 | 0.3429 | 0.9193 | 0.9189 | 0.9268 | 0.9193 |
67
+ | 0.7327 | 0.4267 | 400 | 0.2890 | 0.9520 | 0.9520 | 0.9523 | 0.9520 |
68
+ | 0.7072 | 0.48 | 450 | 0.2939 | 0.9460 | 0.9460 | 0.9475 | 0.9460 |
69
+ | 0.666 | 0.5333 | 500 | 0.2886 | 0.9506 | 0.9505 | 0.9509 | 0.9506 |
70
+ | 0.6596 | 0.5867 | 550 | 0.2800 | 0.9543 | 0.9543 | 0.9550 | 0.9543 |
71
+ | 0.6394 | 0.64 | 600 | 0.2800 | 0.9523 | 0.9522 | 0.9524 | 0.9523 |
72
+ | 0.6734 | 0.6933 | 650 | 0.2740 | 0.9579 | 0.9579 | 0.9586 | 0.9579 |
73
+ | 0.6467 | 0.7467 | 700 | 0.2727 | 0.9582 | 0.9582 | 0.9595 | 0.9582 |
74
+ | 0.6662 | 0.8 | 750 | 0.2711 | 0.9585 | 0.9585 | 0.9586 | 0.9585 |
75
+ | 0.5994 | 0.8533 | 800 | 0.2625 | 0.9656 | 0.9656 | 0.9656 | 0.9656 |
76
+ | 0.6189 | 0.9067 | 850 | 0.2843 | 0.95 | 0.9500 | 0.9511 | 0.95 |
77
+ | 0.6317 | 0.96 | 900 | 0.2600 | 0.9651 | 0.9651 | 0.9658 | 0.9651 |
78
+ | 0.5973 | 1.0128 | 950 | 0.2497 | 0.9733 | 0.9733 | 0.9733 | 0.9733 |
79
+ | 0.5592 | 1.0661 | 1000 | 0.2461 | 0.9738 | 0.9737 | 0.9739 | 0.9738 |
80
+ | 0.6093 | 1.1195 | 1050 | 0.2705 | 0.9567 | 0.9567 | 0.9590 | 0.9567 |
81
+ | 0.5505 | 1.1728 | 1100 | 0.2465 | 0.9716 | 0.9716 | 0.9718 | 0.9716 |
82
+
83
+
84
+ ### Framework versions
85
+
86
+ - Transformers 4.54.1
87
+ - Pytorch 2.7.1+cu126
88
+ - Datasets 4.0.0
89
+ - Tokenizers 0.21.4
hf_config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face compatible configuration for existing Space
3
+ This extends your existing config without breaking it
4
+ """
5
+
6
+ try:
7
+ from transformers import PretrainedConfig
8
+ TRANSFORMERS_AVAILABLE = True
9
+ except ImportError:
10
+ TRANSFORMERS_AVAILABLE = False
11
+ # Fallback configuration
12
+ class PretrainedConfig:
13
+ def __init__(self, **kwargs):
14
+ for key, value in kwargs.items():
15
+ setattr(self, key, value)
16
+
17
+
18
+ class HFResNet18DetectorConfig(PretrainedConfig):
19
+ """
20
+ Hugging Face compatible configuration for your existing model
21
+ Works alongside your existing training config
22
+ """
23
+
24
+ model_type = "resnet18-detector"
25
+
26
+ def __init__(
27
+ self,
28
+ num_classes: int = 2,
29
+ image_size: int = 224,
30
+ architecture: str = "resnet18",
31
+ dropout_rate: float = 0.5,
32
+ freeze_backbone: bool = False,
33
+ pretrained_weights: str = "IMAGENET1K_V1",
34
+ label_smoothing: float = 0.1, # Anti-overfitting: Label smoothing
35
+ weight_decay: float = 0.1, # Anti-overfitting: L2 regularization
36
+ max_grad_norm: float = 1.0, # Anti-overfitting: Gradient clipping
37
+ **kwargs
38
+ ):
39
+ """
40
+ Initialize HF compatible config with anti-overfitting parameters
41
+ """
42
+ self.num_classes = num_classes
43
+ self.image_size = image_size
44
+ self.architecture = architecture
45
+ self.dropout_rate = dropout_rate
46
+ self.freeze_backbone = freeze_backbone
47
+ self.pretrained_weights = pretrained_weights
48
+ self.label_smoothing = label_smoothing
49
+ self.weight_decay = weight_decay
50
+ self.max_grad_norm = max_grad_norm
51
+
52
+ if TRANSFORMERS_AVAILABLE:
53
+ super().__init__(**kwargs)
54
+ else:
55
+ for key, value in kwargs.items():
56
+ setattr(self, key, value)
57
+
58
+
59
+ # Register for auto-loading if transformers is available
60
+ if TRANSFORMERS_AVAILABLE:
61
+ try:
62
+ HFResNet18DetectorConfig.register_for_auto_class()
63
+ except:
64
+ pass
hf_model.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face compatible model wrapper for your existing Space
3
+ This works alongside your existing model loading without breaking it
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Optional
9
+ import sys
10
+ import os
11
+
12
+ # Import transformers components if available
13
+ try:
14
+ from transformers import PreTrainedModel
15
+ from transformers.modeling_outputs import ImageClassifierOutput
16
+ TRANSFORMERS_AVAILABLE = True
17
+ except ImportError:
18
+ TRANSFORMERS_AVAILABLE = False
19
+ # Fallback classes
20
+ class PreTrainedModel(nn.Module):
21
+ def __init__(self, config):
22
+ super().__init__()
23
+ self.config = config
24
+
25
+ class ImageClassifierOutput:
26
+ def __init__(self, loss=None, logits=None):
27
+ self.loss = loss
28
+ self.logits = logits
29
+
30
+ # Import your existing components
31
+ sys.path.append(os.path.join(os.path.dirname(__file__), "training"))
32
+
33
+ try:
34
+ from hf_config import HFResNet18DetectorConfig
35
+ except ImportError:
36
+ # Fallback config
37
+ class HFResNet18DetectorConfig:
38
+ def __init__(self, num_classes=2, **kwargs):
39
+ self.num_classes = num_classes
40
+ for key, value in kwargs.items():
41
+ setattr(self, key, value)
42
+
43
+
44
+ class HFResNet18Detector(PreTrainedModel):
45
+ """
46
+ Hugging Face compatible wrapper for your existing model
47
+ This allows your model to work with HF Trainer and ecosystem
48
+ """
49
+
50
+ config_class = HFResNet18DetectorConfig
51
+
52
+ def __init__(self, config: HFResNet18DetectorConfig):
53
+ super().__init__(config)
54
+
55
+ self.num_labels = getattr(config, 'num_classes', 2)
56
+ self.config = config
57
+
58
+ # Try to use your existing model creation logic first
59
+ try:
60
+ from training.detection_models import create_model
61
+ from training.config import get_model_config
62
+
63
+ model_config = get_model_config("resnet18")
64
+ self.backbone = create_model("resnet18", model_config)
65
+ print("[HF Model] Using existing model creation logic")
66
+
67
+ except Exception as e:
68
+ print(f"[HF Model] Fallback to basic ResNet18: {e}")
69
+ # Fallback to basic ResNet18
70
+ from torchvision.models import resnet18, ResNet18_Weights
71
+
72
+ weights = ResNet18_Weights.IMAGENET1K_V1
73
+ self.backbone = resnet18(weights=weights)
74
+
75
+ # Replace final layer with enhanced regularization
76
+ in_features = self.backbone.fc.in_features
77
+ dropout_rate = getattr(config, 'dropout_rate', 0.5)
78
+ num_classes = getattr(config, 'num_classes', 2)
79
+
80
+ # Multi-layer classification head with stronger regularization
81
+ self.backbone.fc = nn.Sequential(
82
+ nn.Dropout(dropout_rate),
83
+ nn.Linear(in_features, 512),
84
+ nn.ReLU(),
85
+ nn.BatchNorm1d(512),
86
+ nn.Dropout(0.6), # Higher dropout for intermediate layer
87
+ nn.Linear(512, 256),
88
+ nn.ReLU(),
89
+ nn.BatchNorm1d(256),
90
+ nn.Dropout(0.7), # Even higher dropout near output
91
+ nn.Linear(256, num_classes)
92
+ )
93
+
94
+ def forward(
95
+ self,
96
+ pixel_values: Optional[torch.Tensor] = None,
97
+ labels: Optional[torch.Tensor] = None,
98
+ return_dict: Optional[bool] = None,
99
+ **kwargs
100
+ ):
101
+ """
102
+ Forward pass compatible with both HF and your existing code
103
+ """
104
+ # Handle both HF format and your existing format
105
+ if pixel_values is None:
106
+ raise ValueError("pixel_values must be provided")
107
+
108
+ # Forward pass through your existing model
109
+ logits = self.backbone(pixel_values)
110
+
111
+ loss = None
112
+ if labels is not None:
113
+ # Ensure labels are properly formatted
114
+ if isinstance(labels, torch.Tensor):
115
+ labels = labels.long()
116
+ else:
117
+ labels = torch.tensor(labels, dtype=torch.long)
118
+
119
+ # Ensure labels are 1D
120
+ if labels.dim() > 1:
121
+ labels = labels.squeeze()
122
+
123
+ # Use label smoothing to combat overfitting with proper error handling
124
+ try:
125
+ label_smoothing = getattr(self.config, 'label_smoothing', 0.1)
126
+ loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
127
+ loss = loss_fct(logits, labels)
128
+ except Exception as e:
129
+ print(f"[HF Model] Label smoothing failed ({e}), falling back to standard CrossEntropyLoss")
130
+ # Fallback to standard cross entropy if label smoothing fails
131
+ loss_fct = nn.CrossEntropyLoss()
132
+ loss = loss_fct(logits, labels)
133
+
134
+ if TRANSFORMERS_AVAILABLE and return_dict:
135
+ return ImageClassifierOutput(
136
+ loss=loss,
137
+ logits=logits,
138
+ )
139
+ else:
140
+ # Fallback for non-HF usage
141
+ if loss is not None:
142
+ return loss, logits
143
+ return logits
144
+
145
+ def predict_compatibility(self, x):
146
+ """
147
+ Compatibility method for your existing inference code
148
+ """
149
+ return self.backbone(x)
150
+
151
+
152
+ # Register for auto-loading if transformers is available
153
+ if TRANSFORMERS_AVAILABLE:
154
+ try:
155
+ HFResNet18Detector.register_for_auto_class("AutoModelForImageClassification")
156
+ except:
157
+ pass
158
+
159
+
160
+ def create_hf_compatible_model(existing_model_path=None):
161
+ """
162
+ Helper function to create HF compatible model from existing weights
163
+ """
164
+ config = HFResNet18DetectorConfig()
165
+ model = HFResNet18Detector(config)
166
+
167
+ if existing_model_path and os.path.exists(existing_model_path):
168
+ try:
169
+ # Load your existing model weights
170
+ checkpoint = torch.load(existing_model_path, map_location="cpu", weights_only=False)
171
+ if 'model_state_dict' in checkpoint:
172
+ model.backbone.load_state_dict(checkpoint['model_state_dict'])
173
+ else:
174
+ model.backbone.load_state_dict(checkpoint)
175
+ print(f"[HF Model] Loaded weights from {existing_model_path}")
176
+ except Exception as e:
177
+ print(f"[HF Model] Failed to load weights: {e}")
178
+
179
+ return model
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:03c332a43edc3107efc8fc01433ec598a08a0dab97d04e8c25f48b6a637c2efa
3
  size 45284592
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d66bdfeedccc35f11a7b63a1c1864eae577135e56fa9cb8e00a828e6ce274d4d
3
  size 45284592