""" Training script for ResNet18 AI Image Detector Compatible with Hugging Face Trainer API """ import os import sys import argparse import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torchvision import transforms from PIL import Image from pathlib import Path from typing import Dict, Any, Optional, Tuple import json import logging from sklearn.metrics import accuracy_score, precision_recall_fscore_support import numpy as np # Hugging Face imports from transformers import ( Trainer, TrainingArguments, AutoImageProcessor, EarlyStoppingCallback ) from datasets import Dataset as HFDataset, DatasetDict # Local imports from config import ResNet18DetectorConfig from detection_models import ResNet18Detector # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class ImageDetectionDataset(Dataset): """PyTorch Dataset for image classification""" def __init__(self, image_paths, labels, transform=None): self.image_paths = image_paths self.labels = labels self.transform = transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx): image_path = self.image_paths[idx] label = self.labels[idx] # Load image try: image = Image.open(image_path).convert('RGB') except Exception as e: logger.error(f"Error loading image {image_path}: {e}") # Return a black image as fallback image = Image.new('RGB', (224, 224), (0, 0, 0)) if self.transform: image = self.transform(image) return { 'pixel_values': image, 'labels': torch.tensor(label, dtype=torch.long) } def get_transforms(): """Get data augmentation transforms""" train_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) val_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return train_transform, val_transform def load_dataset(data_dir: str, split_ratio: float = 0.8) -> Tuple[ImageDetectionDataset, ImageDetectionDataset]: """Load and split the dataset""" data_path = Path(data_dir) # Assuming directory structure: data_dir/real/, data_dir/ai_generated/ real_dir = data_path / "real" ai_dir = data_path / "ai_generated" if not real_dir.exists() or not ai_dir.exists(): raise ValueError(f"Expected directories 'real' and 'ai_generated' in {data_dir}") # Collect all image paths and labels image_paths = [] labels = [] # Real images (label 0) for img_path in real_dir.glob("*"): if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']: image_paths.append(str(img_path)) labels.append(0) # AI-generated images (label 1) for img_path in ai_dir.glob("*"): if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']: image_paths.append(str(img_path)) labels.append(1) # Shuffle and split indices = np.random.permutation(len(image_paths)) split_idx = int(len(indices) * split_ratio) train_indices = indices[:split_idx] val_indices = indices[split_idx:] train_paths = [image_paths[i] for i in train_indices] train_labels = [labels[i] for i in train_indices] val_paths = [image_paths[i] for i in val_indices] val_labels = [labels[i] for i in val_indices] train_transform, val_transform = get_transforms() train_dataset = ImageDetectionDataset(train_paths, train_labels, train_transform) val_dataset = ImageDetectionDataset(val_paths, val_labels, val_transform) logger.info(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples") return train_dataset, val_dataset def compute_metrics(eval_pred): """Compute evaluation metrics""" predictions, labels = eval_pred predictions = np.argmax(predictions, axis=1) precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted') accuracy = accuracy_score(labels, predictions) return { 'accuracy': accuracy, 'f1': f1, 'precision': precision, 'recall': recall } def train_model(args): """Main training function""" # Set random seeds for reproducibility torch.manual_seed(args.seed) np.random.seed(args.seed) # Load datasets train_dataset, val_dataset = load_dataset(args.data_dir, args.split_ratio) # Initialize configuration config = ResNet18DetectorConfig( num_classes=2, image_size=224, architecture="resnet18", dropout_rate=args.dropout_rate, freeze_backbone=args.freeze_backbone ) # Initialize model model = ResNet18Detector(config) # Training arguments training_args = TrainingArguments( output_dir=args.output_dir, num_train_epochs=args.num_epochs, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, warmup_steps=args.warmup_steps, weight_decay=args.weight_decay, logging_dir=f"{args.output_dir}/logs", logging_steps=args.logging_steps, evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", greater_is_better=True, save_total_limit=2, dataloader_num_workers=args.num_workers, remove_unused_columns=False, push_to_hub=args.push_to_hub, hub_model_id=args.hub_model_id if args.push_to_hub else None, report_to=["tensorboard"] if args.use_tensorboard else [], ) # Initialize trainer trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] if args.early_stopping else [], ) # Train the model logger.info("Starting training...") trainer.train() # Evaluate the model logger.info("Evaluating model...") eval_results = trainer.evaluate() logger.info(f"Evaluation results: {eval_results}") # Save the model logger.info(f"Saving model to {args.output_dir}") trainer.save_model() # Push to hub if specified if args.push_to_hub: logger.info(f"Pushing model to Hugging Face Hub: {args.hub_model_id}") trainer.push_to_hub() return eval_results def main(): parser = argparse.ArgumentParser(description="Train ResNet18 AI Image Detector") # Data arguments parser.add_argument("--data_dir", type=str, required=True, help="Path to dataset directory") parser.add_argument("--split_ratio", type=float, default=0.8, help="Train/validation split ratio") # Model arguments parser.add_argument("--dropout_rate", type=float, default=0.5, help="Dropout rate") parser.add_argument("--freeze_backbone", action="store_true", help="Freeze backbone weights") # Training arguments parser.add_argument("--output_dir", type=str, default="./results", help="Output directory") parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs") parser.add_argument("--batch_size", type=int, default=16, help="Batch size") parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay") parser.add_argument("--warmup_steps", type=int, default=500, help="Warmup steps") parser.add_argument("--logging_steps", type=int, default=10, help="Logging steps") parser.add_argument("--num_workers", type=int, default=4, help="Number of data loader workers") parser.add_argument("--seed", type=int, default=42, help="Random seed") # Callbacks parser.add_argument("--early_stopping", action="store_true", help="Use early stopping") parser.add_argument("--use_tensorboard", action="store_true", help="Use TensorBoard logging") # Hugging Face Hub parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub") parser.add_argument("--hub_model_id", type=str, help="Hugging Face Hub model ID") args = parser.parse_args() # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Train the model results = train_model(args) # Save results with open(f"{args.output_dir}/training_results.json", "w") as f: json.dump(results, f, indent=2) logger.info("Training completed successfully!") if __name__ == "__main__": main()