Model save

Files changed:
- README.md (+78 -244)
- hf_config.py (+64 -0)
- hf_model.py (+179 -0)
- model.safetensors (+1 -1)
README.md CHANGED
@@ -1,255 +1,89 @@
 ---
-license: apache-2.0
-language:
-- en
-base_model:
-- bobbysam/resnet18-image-detector
 library_name: transformers
-pipeline_tag: image-classification
 tags:
--
-- image-classification
-- ai-detection
-- pytorch
-- resnet
-datasets:
-- custom
+- generated_from_trainer
 metrics:
 - accuracy
+- f1
 - precision
 - recall
-- f1
 model-index:
 - name: resnet18-image-detector
-  results:
-  - task:
-      type: image-classification
-      name: AI vs Real Image Detection
-    dataset:
-      name: Custom AI Detection Dataset
-      type: custom
-    metrics:
-    - type: accuracy
-      value: 0.95
-      name: Accuracy
-    - type: f1
-      value: 0.94
-      name: F1 Score
-    - type: precision
-      value: 0.93
-      name: Precision
-    - type: recall
-      value: 0.96
-      name: Recall
----
-
-# ResNet18 AI Image Detector
-
-**Repository:** [bobbysam/resnet18-image-detector](https://huggingface.co/bobbysam/resnet18-image-detector)
-
-[Training Space](https://huggingface.co/spaces/autotrain-projects/train-resnet18-detector)
-[Deployment Space](https://huggingface.co/spaces/autotrain-projects/deploy-resnet18-detector)
-
----
-
-## 🧠 What does this model do?
-
-This is a **ResNet18-based deep neural network** trained to **detect whether an input image is a real photograph or AI-generated** (binary classification: `real` vs. `ai_generated`).
-It is part of the [ProofGuard](https://github.com/Proofguard/proofguard-backend) project and can be used to build trustworthy AI image detection pipelines.
-
-**Key Features:**
-- 🔬 Binary classification: real vs. AI-generated images
-- 🚀 Fast inference with the ResNet18 architecture
-- 🤗 Compatible with Hugging Face Transformers
-- 📊 Comprehensive evaluation metrics
-- 🎯 Easy-to-use inference API
-
----
-
-## 🚀 Quick Start
-
-### **Option 1: Using Hugging Face Transformers (Recommended)**
-
-```python
-from transformers import AutoModelForImageClassification, AutoImageProcessor
-from PIL import Image
-import torch
-
-# Load model and processor
-model = AutoModelForImageClassification.from_pretrained("bobbysam/resnet18-image-detector")
-processor = AutoImageProcessor.from_pretrained("bobbysam/resnet18-image-detector")
-
-# Load and process image
-image = Image.open("your_image.jpg")
-inputs = processor(image, return_tensors="pt")
-
-# Make prediction
-with torch.no_grad():
-    outputs = model(**inputs)
-    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
-    prediction = torch.argmax(probabilities, dim=-1).item()
-
-labels = ["Real", "AI-generated"]
-confidence = probabilities[0, prediction].item()
-print(f"Prediction: {labels[prediction]} (Confidence: {confidence:.2%})")
-```
-
-### **Option 2: Using the Inference Script**
-
-```bash
-# Clone the repository
-git clone https://huggingface.co/bobbysam/resnet18-image-detector
-cd resnet18-image-detector
-
-# Install dependencies
-pip install -r requirements.txt
-
-# Run inference
-python inference.py --image path/to/your/image.jpg --model ./
-```
-
-### **Option 3: Using the Custom Wrapper**
-
-```python
-from inference import AIImageDetector
-
-# Initialize detector
-detector = AIImageDetector()
-
-# Make prediction
-result = detector.predict("your_image.jpg")
-print(f"Prediction: {result['prediction']}")
-print(f"Confidence: {result['confidence']:.2%}")
-```
-
----
-
-## 🏋️ Training Your Own Model
-
-### **Quick Training with Hugging Face Trainer**
-
-```bash
-# 1. Set up the environment
-python setup.py
-
-# 2. Download/prepare your dataset
-python download_dataset.py --dataset_type custom --source_dir /path/to/your/data
-
-# 3. Train the model
-python trainer.py \
-    --data_dir ./data \
-    --output_dir ./results \
-    --num_epochs 10 \
-    --batch_size 16 \
-    --push_to_hub \
-    --hub_model_id your-username/resnet18-detector
-```
-
-### **Training Arguments**
-
-| Argument | Description | Default |
-|----------|-------------|---------|
-| `--data_dir` | Path to dataset directory | Required |
-| `--output_dir` | Output directory for model | `./results` |
-| `--num_epochs` | Number of training epochs | 10 |
-| `--batch_size` | Training batch size | 16 |
-| `--learning_rate` | Learning rate | 2e-5 |
-| `--dropout_rate` | Dropout rate for regularization | 0.5 |
-| `--freeze_backbone` | Freeze ResNet backbone | False |
-| `--push_to_hub` | Push model to HF Hub | False |
-| `--hub_model_id` | Hugging Face model ID | None |
-
-### **Dataset Structure**
-
-Your dataset should be organized as follows:
-```
-data/
-├── real/
-│   ├── image1.jpg
-│   ├── image2.jpg
-│   └── ...
-└── ai_generated/
-    ├── image1.jpg
-    ├── image2.jpg
-    └── ...
-```
-
----
-
-## 🚀 Deployment Options
-
-This model supports multiple deployment options through Hugging Face:
-
-### **1. Hugging Face Inference Endpoints**
-- Production-ready inference API
-- Auto-scaling and load balancing
-- Pay-per-request pricing
-
-### **2. Amazon SageMaker**
-- Deploy directly to AWS SageMaker
-- Enterprise-grade infrastructure
-- Custom scaling policies
-
-### **3. Azure ML**
-- Deploy to Azure Machine Learning
-- Integration with Azure services
-- Enterprise security features
-
-### **4. Local Deployment**
-```python
-# Load model locally
-import torch
-from transformers import pipeline
-
-classifier = pipeline(
-    "image-classification",
-    model="bobbysam/resnet18-image-detector",
-    device=0 if torch.cuda.is_available() else -1
-)
-
-result = classifier("path/to/image.jpg")
-```
-
----
-
-## 📥 Input format and requirements
-
-- **Input:** RGB image (PIL Image or file path), resized to 224x224 and normalized with ImageNet statistics.
-- **Output:**
-  - `0` = Real photograph
-  - `1` = AI-generated image
-
----
-
-## 📦 Model details
-
-- **Architecture:** ResNet18 (PyTorch, torchvision)
-- **Training data:** Real & AI-generated images (see the [ProofGuard project](https://github.com/Proofguard/proofguard-backend))
-- **Framework:** PyTorch
-- **Size:** ~60 MB
-
----
-
-## ⚖️ License and usage
-
-- **License:** [MIT](https://opensource.org/license/mit/) (or specify your own)
-- **Usage restrictions:** For research, education, and non-commercial projects.
-  _For commercial use, contact the author or check the ProofGuard project license._
-
----
-
-## 📚 Citation
-
-If you use this model, please cite:
-```text
-ProofGuard: AI Image Authenticity Detection
-https://github.com/Proofguard/proofguard-backend
-Model by @bobbysam (Hugging Face)
-```
-
----
-
-## 🛠️ Maintainer
-
-- [@bobbysam](https://huggingface.co/bobbysam)
-- [ProofGuard GitHub](https://github.com/Proofguard/proofguard-backend)
-
+  results: []
 ---
 
-
+<!-- This model card has been generated automatically according to the information the Trainer had access to. You
+should probably proofread and complete it, then remove this comment. -->
+
+# resnet18-image-detector
+
+This model is a fine-tuned version of [](https://huggingface.co/) on the None dataset.
+It achieves the following results on the evaluation set:
+- Loss: 0.2461
+- Accuracy: 0.9738
+- F1: 0.9737
+- Precision: 0.9739
+- Recall: 0.9738
+
+## Model description
+
+More information needed
+
+## Intended uses & limitations
+
+More information needed
+
+## Training and evaluation data
+
+More information needed
+
+## Training procedure
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: 0.0001
+- train_batch_size: 16
+- eval_batch_size: 16
+- seed: 42
+- gradient_accumulation_steps: 2
+- total_train_batch_size: 32
+- optimizer: adamw_torch with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
+- lr_scheduler_type: cosine_with_restarts
+- lr_scheduler_warmup_ratio: 0.1
+- num_epochs: 3
+
+### Training results
+
+| Training Loss | Epoch  | Step | Validation Loss | Accuracy | F1     | Precision | Recall |
+|:-------------:|:------:|:----:|:---------------:|:--------:|:------:|:---------:|:------:|
+| 1.3887        | 0.0533 | 50   | 0.6371          | 0.7338   | 0.7336 | 0.7345    | 0.7338 |
+| 1.1433        | 0.1067 | 100  | 0.4604          | 0.8571   | 0.8569 | 0.8591    | 0.8571 |
+| 0.8563        | 0.16   | 150  | 0.3538          | 0.9081   | 0.9080 | 0.9094    | 0.9081 |
+| 0.7671        | 0.2133 | 200  | 0.3244          | 0.9277   | 0.9277 | 0.9282    | 0.9277 |
+| 0.7213        | 0.2667 | 250  | 0.3244          | 0.9301   | 0.9300 | 0.9307    | 0.9301 |
+| 0.6996        | 0.32   | 300  | 0.3187          | 0.9324   | 0.9323 | 0.9339    | 0.9324 |
+| 0.6975        | 0.3733 | 350  | 0.3429          | 0.9193   | 0.9189 | 0.9268    | 0.9193 |
+| 0.7327        | 0.4267 | 400  | 0.2890          | 0.9520   | 0.9520 | 0.9523    | 0.9520 |
+| 0.7072        | 0.48   | 450  | 0.2939          | 0.9460   | 0.9460 | 0.9475    | 0.9460 |
+| 0.666         | 0.5333 | 500  | 0.2886          | 0.9506   | 0.9505 | 0.9509    | 0.9506 |
+| 0.6596        | 0.5867 | 550  | 0.2800          | 0.9543   | 0.9543 | 0.9550    | 0.9543 |
+| 0.6394        | 0.64   | 600  | 0.2800          | 0.9523   | 0.9522 | 0.9524    | 0.9523 |
+| 0.6734        | 0.6933 | 650  | 0.2740          | 0.9579   | 0.9579 | 0.9586    | 0.9579 |
+| 0.6467        | 0.7467 | 700  | 0.2727          | 0.9582   | 0.9582 | 0.9595    | 0.9582 |
+| 0.6662        | 0.8    | 750  | 0.2711          | 0.9585   | 0.9585 | 0.9586    | 0.9585 |
+| 0.5994        | 0.8533 | 800  | 0.2625          | 0.9656   | 0.9656 | 0.9656    | 0.9656 |
+| 0.6189        | 0.9067 | 850  | 0.2843          | 0.95     | 0.9500 | 0.9511    | 0.95   |
+| 0.6317        | 0.96   | 900  | 0.2600          | 0.9651   | 0.9651 | 0.9658    | 0.9651 |
+| 0.5973        | 1.0128 | 950  | 0.2497          | 0.9733   | 0.9733 | 0.9733    | 0.9733 |
+| 0.5592        | 1.0661 | 1000 | 0.2461          | 0.9738   | 0.9737 | 0.9739    | 0.9738 |
+| 0.6093        | 1.1195 | 1050 | 0.2705          | 0.9567   | 0.9567 | 0.9590    | 0.9567 |
+| 0.5505        | 1.1728 | 1100 | 0.2465          | 0.9716   | 0.9716 | 0.9718    | 0.9716 |
+
+
+### Framework versions
+
+- Transformers 4.54.1
+- Pytorch 2.7.1+cu126
+- Datasets 4.0.0
+- Tokenizers 0.21.4
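For reference, the hyperparameter list in the new model card maps almost one-to-one onto `transformers.TrainingArguments`. The sketch below is a reconstruction from that list, not the repository's actual `trainer.py`, and the output path is illustrative:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",            # illustrative path
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,     # 16 x 2 = total train batch size 32
    num_train_epochs=3,
    seed=42,
    optim="adamw_torch",               # AdamW with betas=(0.9, 0.999), eps=1e-08
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.1,
)
```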
hf_config.py ADDED
@@ -0,0 +1,64 @@
+"""
+Hugging Face compatible configuration for your existing Space.
+This extends your existing config without breaking it.
+"""
+
+try:
+    from transformers import PretrainedConfig
+    TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    TRANSFORMERS_AVAILABLE = False
+
+    # Fallback configuration
+    class PretrainedConfig:
+        def __init__(self, **kwargs):
+            for key, value in kwargs.items():
+                setattr(self, key, value)
+
+
+class HFResNet18DetectorConfig(PretrainedConfig):
+    """
+    Hugging Face compatible configuration for your existing model.
+    Works alongside your existing training config.
+    """
+
+    model_type = "resnet18-detector"
+
+    def __init__(
+        self,
+        num_classes: int = 2,
+        image_size: int = 224,
+        architecture: str = "resnet18",
+        dropout_rate: float = 0.5,
+        freeze_backbone: bool = False,
+        pretrained_weights: str = "IMAGENET1K_V1",
+        label_smoothing: float = 0.1,  # Anti-overfitting: label smoothing
+        weight_decay: float = 0.1,     # Anti-overfitting: L2 regularization
+        max_grad_norm: float = 1.0,    # Anti-overfitting: gradient clipping
+        **kwargs
+    ):
+        """
+        Initialize the HF-compatible config with anti-overfitting parameters.
+        """
+        self.num_classes = num_classes
+        self.image_size = image_size
+        self.architecture = architecture
+        self.dropout_rate = dropout_rate
+        self.freeze_backbone = freeze_backbone
+        self.pretrained_weights = pretrained_weights
+        self.label_smoothing = label_smoothing
+        self.weight_decay = weight_decay
+        self.max_grad_norm = max_grad_norm
+
+        if TRANSFORMERS_AVAILABLE:
+            super().__init__(**kwargs)
+        else:
+            for key, value in kwargs.items():
+                setattr(self, key, value)
+
+
+# Register for auto-loading if transformers is available
+if TRANSFORMERS_AVAILABLE:
+    try:
+        HFResNet18DetectorConfig.register_for_auto_class()
+    except Exception:
+        pass
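Because the custom attributes are set before `PretrainedConfig.__init__` runs, they round-trip through the standard config serialization. A minimal sketch, assuming `transformers` is installed and using an illustrative directory name:

```python
from hf_config import HFResNet18DetectorConfig

config = HFResNet18DetectorConfig(num_classes=2, dropout_rate=0.5)
config.save_pretrained("./resnet18-detector")  # writes config.json; path is illustrative

reloaded = HFResNet18DetectorConfig.from_pretrained("./resnet18-detector")
print(reloaded.model_type, reloaded.label_smoothing)  # resnet18-detector 0.1
```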
hf_model.py ADDED
@@ -0,0 +1,179 @@
+"""
+Hugging Face compatible model wrapper for your existing Space.
+This works alongside your existing model loading without breaking it.
+"""
+
+import torch
+import torch.nn as nn
+from typing import Optional
+import sys
+import os
+
+# Import transformers components if available
+try:
+    from transformers import PreTrainedModel
+    from transformers.modeling_outputs import ImageClassifierOutput
+    TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    TRANSFORMERS_AVAILABLE = False
+
+    # Fallback classes
+    class PreTrainedModel(nn.Module):
+        def __init__(self, config):
+            super().__init__()
+            self.config = config
+
+    class ImageClassifierOutput:
+        def __init__(self, loss=None, logits=None):
+            self.loss = loss
+            self.logits = logits
+
+# Import your existing components
+sys.path.append(os.path.join(os.path.dirname(__file__), "training"))
+
+try:
+    from hf_config import HFResNet18DetectorConfig
+except ImportError:
+    # Fallback config
+    class HFResNet18DetectorConfig:
+        def __init__(self, num_classes=2, **kwargs):
+            self.num_classes = num_classes
+            for key, value in kwargs.items():
+                setattr(self, key, value)
+
+
+class HFResNet18Detector(PreTrainedModel):
+    """
+    Hugging Face compatible wrapper for your existing model.
+    This allows the model to work with the HF Trainer and ecosystem.
+    """
+
+    config_class = HFResNet18DetectorConfig
+
+    def __init__(self, config: HFResNet18DetectorConfig):
+        super().__init__(config)
+
+        self.num_labels = getattr(config, 'num_classes', 2)
+        self.config = config
+
+        # Try to use your existing model creation logic first
+        try:
+            from training.detection_models import create_model
+            from training.config import get_model_config
+
+            model_config = get_model_config("resnet18")
+            self.backbone = create_model("resnet18", model_config)
+            print("[HF Model] Using existing model creation logic")
+        except Exception as e:
+            print(f"[HF Model] Falling back to basic ResNet18: {e}")
+            # Fallback to basic ResNet18
+            from torchvision.models import resnet18, ResNet18_Weights
+
+            weights = ResNet18_Weights.IMAGENET1K_V1
+            self.backbone = resnet18(weights=weights)
+
+        # Replace the final layer with an enhanced-regularization head
+        in_features = self.backbone.fc.in_features
+        dropout_rate = getattr(config, 'dropout_rate', 0.5)
+        num_classes = getattr(config, 'num_classes', 2)
+
+        # Multi-layer classification head with stronger regularization
+        self.backbone.fc = nn.Sequential(
+            nn.Dropout(dropout_rate),
+            nn.Linear(in_features, 512),
+            nn.ReLU(),
+            nn.BatchNorm1d(512),
+            nn.Dropout(0.6),  # Higher dropout for the intermediate layer
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.BatchNorm1d(256),
+            nn.Dropout(0.7),  # Even higher dropout near the output
+            nn.Linear(256, num_classes)
+        )
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ):
+        """
+        Forward pass compatible with both HF and your existing code.
+        """
+        # Handle both the HF format and your existing format
+        if pixel_values is None:
+            raise ValueError("pixel_values must be provided")
+
+        # Forward pass through your existing model
+        logits = self.backbone(pixel_values)
+
+        loss = None
+        if labels is not None:
+            # Ensure labels are properly formatted
+            if isinstance(labels, torch.Tensor):
+                labels = labels.long()
+            else:
+                labels = torch.tensor(labels, dtype=torch.long)
+
+            # Ensure labels are 1D
+            if labels.dim() > 1:
+                labels = labels.squeeze()
+
+            # Use label smoothing to combat overfitting, with proper error handling
+            try:
+                label_smoothing = getattr(self.config, 'label_smoothing', 0.1)
+                loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+                loss = loss_fct(logits, labels)
+            except Exception as e:
+                print(f"[HF Model] Label smoothing failed ({e}), falling back to standard CrossEntropyLoss")
+                # Fall back to standard cross entropy if label smoothing fails
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits, labels)
+
+        if TRANSFORMERS_AVAILABLE and return_dict:
+            return ImageClassifierOutput(
+                loss=loss,
+                logits=logits,
+            )
+        else:
+            # Fallback for non-HF usage
+            if loss is not None:
+                return loss, logits
+            return logits
+
+    def predict_compatibility(self, x):
+        """
+        Compatibility method for your existing inference code.
+        """
+        return self.backbone(x)
+
+
+# Register for auto-loading if transformers is available
+if TRANSFORMERS_AVAILABLE:
+    try:
+        HFResNet18Detector.register_for_auto_class("AutoModelForImageClassification")
+    except Exception:
+        pass
+
+
+def create_hf_compatible_model(existing_model_path=None):
+    """
+    Helper function to create an HF-compatible model from existing weights.
+    """
+    config = HFResNet18DetectorConfig()
+    model = HFResNet18Detector(config)
+
+    if existing_model_path and os.path.exists(existing_model_path):
+        try:
+            # Load your existing model weights
+            checkpoint = torch.load(existing_model_path, map_location="cpu", weights_only=False)
+            if 'model_state_dict' in checkpoint:
+                model.backbone.load_state_dict(checkpoint['model_state_dict'])
+            else:
+                model.backbone.load_state_dict(checkpoint)
+            print(f"[HF Model] Loaded weights from {existing_model_path}")
+        except Exception as e:
+            print(f"[HF Model] Failed to load weights: {e}")
+
+    return model
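A minimal local smoke test of the wrapper; the checkpoint path is hypothetical, and without one the helper falls through to the ImageNet-initialized fallback weights:

```python
import torch
from hf_model import create_hf_compatible_model

# "models/detector.pth" is a hypothetical path; omit it to use the fallback weights
model = create_hf_compatible_model("models/detector.pth")
model.eval()

pixel_values = torch.randn(2, 3, 224, 224)  # dummy batch of 224x224 RGB images
with torch.no_grad():
    out = model(pixel_values=pixel_values, return_dict=True)

# ImageClassifierOutput when transformers is installed, a raw tensor otherwise
logits = out.logits if hasattr(out, "logits") else out
print(logits.shape)  # torch.Size([2, 2])
```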
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:d66bdfeedccc35f11a7b63a1c1864eae577135e56fa9cb8e00a828e6ce274d4d
 size 45284592
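Since this is a Git LFS pointer, the `oid` is the SHA-256 of the actual weights file, so a download can be verified against it. A minimal sketch, assuming the file sits in the current directory:

```python
import hashlib

EXPECTED_OID = "d66bdfeedccc35f11a7b63a1c1864eae577135e56fa9cb8e00a828e6ce274d4d"

h = hashlib.sha256()
with open("model.safetensors", "rb") as f:            # path is illustrative
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)

assert h.hexdigest() == EXPECTED_OID, "model.safetensors does not match the LFS pointer"
print("OK: file matches the pointer (expected size: 45284592 bytes)")
```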