Rohan Kumar Shah committed
Commit · b7c5baf · 1 Parent(s): bfb2e8a
added real and forgery detection model
- features/real_forged_classifier/controller.py +36 -0
- features/real_forged_classifier/inferencer.py +52 -0
- features/real_forged_classifier/main.py +26 -0
- features/real_forged_classifier/model.py +34 -0
- features/real_forged_classifier/model_loader.py +60 -0
- features/real_forged_classifier/preprocessor.py +67 -0
- features/real_forged_classifier/routes.py +37 -0
features/real_forged_classifier/controller.py
ADDED
@@ -0,0 +1,36 @@
+from typing import IO
+from preprocessor import preprocessor
+from inferencer import interferencer
+
+class ClassificationController:
+    """
+    Controller to handle the image classification logic.
+    """
+    def classify_image(self, image_file: IO) -> dict:
+        """
+        Orchestrates the classification of a single image file.
+
+        Args:
+            image_file (IO): The image file to classify.
+
+        Returns:
+            dict: The classification result.
+        """
+        try:
+            # Step 1: Preprocess the image
+            image_tensor = preprocessor.process(image_file)
+
+            # Step 2: Perform inference
+            result = interferencer.predict(image_tensor)
+
+            return result
+        except ValueError as e:
+            # Handle specific errors like invalid images
+            return {"error": str(e)}
+        except Exception as e:
+            # Handle unexpected errors
+            print(f"An unexpected error occurred: {e}")
+            return {"error": "An internal error occurred during classification."}
+
+# Create a single instance of the controller
+controller = ClassificationController()
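For context, a minimal smoke-test sketch of how this controller could be driven directly, assuming it is run from inside features/real_forged_classifier/ so the flat imports resolve and the model download in model_loader.py succeeds; sample.jpg is a hypothetical test image:

    # Hypothetical local smoke test, not part of the commit.
    from controller import controller

    with open("sample.jpg", "rb") as f:  # hypothetical test image
        result = controller.classify_image(f)
    print(result)  # e.g. {"classification": "real", "confidence": 0.93} or {"error": "..."}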
features/real_forged_classifier/inferencer.py
ADDED
@@ -0,0 +1,52 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+# Import the globally loaded models instance
+from model_loader import models
+
+class Interferencer:
+    """
+    Performs inference using the FFT CNN model.
+    """
+    def __init__(self):
+        """
+        Initializes the interferencer with the loaded model.
+        """
+        self.fft_model = models.fft_model
+
+    @torch.no_grad()
+    def predict(self, image_tensor: torch.Tensor) -> dict:
+        """
+        Takes a preprocessed image tensor and returns the classification result.
+
+        Args:
+            image_tensor (torch.Tensor): The preprocessed image tensor.
+
+        Returns:
+            dict: A dictionary containing the classification label and confidence score.
+        """
+        # 1. Get model outputs (logits)
+        outputs = self.fft_model(image_tensor)
+
+        # 2. Apply softmax to get probabilities
+        probabilities = F.softmax(outputs, dim=1)
+
+        # 3. Get the confidence and the predicted class index
+        confidence, predicted_idx = torch.max(probabilities, 1)
+
+        prediction = predicted_idx.item()
+
+        # 4. Map the prediction to a human-readable label
+        # Ensure this mapping matches the labels used during training
+        # Typically: 0 -> fake, 1 -> real
+        label_map = {0: 'fake', 1: 'real'}
+        classification_label = label_map.get(prediction, "unknown")
+
+        return {
+            "classification": classification_label,
+            "confidence": confidence.item()
+        }
+
+# Create a single instance of the interferencer
+interferencer = Interferencer()
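The predict step above is a standard softmax-over-logits followed by argmax; a self-contained illustration of just that step with made-up logits (values are illustrative, not from the model):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[1.2, -0.4]])       # dummy batch of one image, 2 classes
    probs = F.softmax(logits, dim=1)           # roughly tensor([[0.832, 0.168]])
    confidence, predicted_idx = torch.max(probs, 1)
    label = {0: 'fake', 1: 'real'}.get(predicted_idx.item(), "unknown")
    print(label, round(confidence.item(), 3))  # fake 0.832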
features/real_forged_classifier/main.py
ADDED
@@ -0,0 +1,26 @@
+from fastapi import FastAPI
+from routes import router as api_router
+
+# Initialize the FastAPI app
+app = FastAPI(
+    title="Real vs. Fake Image Classification API",
+    description="An API to classify images as real or forged using FFT and a CNN.",
+    version="1.0.0"
+)
+
+# Include the API router
+# All routes defined in routes.py will be available under the /api prefix
+app.include_router(api_router, prefix="/api", tags=["Classification"])
+
+@app.get("/", tags=["Root"])
+async def read_root():
+    """
+    A simple root endpoint to confirm the API is running.
+    """
+    return {"message": "Welcome to the Image Classification API. Go to /docs for the API documentation."}
+
+# To run this application:
+# 1. Make sure you have all dependencies from requirements.txt installed.
+# 2. The FFT CNN weights are downloaded automatically from the Hugging Face Hub on startup.
+# 3. Run the following command in your terminal:
+#    uvicorn main:app --reload
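A hedged smoke test of the app with FastAPI's TestClient (assumes httpx is installed and the model weights download successfully at import time):

    # Hypothetical smoke test, not part of the commit.
    from fastapi.testclient import TestClient
    from main import app

    client = TestClient(app)
    resp = client.get("/")
    print(resp.status_code, resp.json())  # 200 and the welcome message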
features/real_forged_classifier/model.py
ADDED
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FFTCNN(nn.Module):
+    """
+    Defines the Convolutional Neural Network architecture.
+    This structure must match the model that was trained and saved.
+    """
+    def __init__(self):
+        super(FFTCNN, self).__init__()
+        # Convolutional feature extractor, defined as an instance attribute
+        self.conv_layers = nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(16, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(kernel_size=2, stride=2)
+        )
+
+        # Fully connected classifier head
+        self.fc_layers = nn.Sequential(
+            nn.Linear(32 * 56 * 56, 128),  # 224x224 input halved twice by pooling: 224 -> 112 -> 56
+            nn.ReLU(),
+            nn.Linear(128, 2)  # 2 output classes
+        )
+
+    def forward(self, x):
+        # Extract features, flatten, then classify
+        x = self.conv_layers(x)
+        x = x.view(x.size(0), -1)  # Flatten the feature maps
+        x = self.fc_layers(x)
+        return x
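The flattened size 32 * 56 * 56 follows from the two 2x2 max-pools halving a 224x224 input twice (224 -> 112 -> 56) across 32 channels; a quick shape check with random input (assumes this file is importable as model):

    import torch
    from model import FFTCNN

    net = FFTCNN()
    dummy = torch.randn(1, 1, 224, 224)  # batch of one grayscale 224x224 image
    print(net(dummy).shape)              # torch.Size([1, 2]) -- logits for fake/real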
features/real_forged_classifier/model_loader.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+from model import FFTCNN  # Import the model architecture
+
+class ModelLoader:
+    """
+    A class to load and hold the PyTorch CNN model.
+    """
+    def __init__(self, model_repo_id: str, model_filename: str):
+        """
+        Initializes the ModelLoader and loads the model.
+
+        Args:
+            model_repo_id (str): The repository ID on Hugging Face.
+            model_filename (str): The name of the model file (.pth) in the repository.
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")
+
+        self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
+        print("FFT CNN model loaded successfully.")
+
+    def _load_fft_model(self, repo_id: str, filename: str):
+        """
+        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.
+
+        Args:
+            repo_id (str): The repository ID on Hugging Face.
+            filename (str): The name of the model file (.pth) in the repository.
+
+        Returns:
+            The loaded PyTorch model object.
+        """
+        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
+        try:
+            # Download the model file from the Hub. It returns the cached path.
+            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+            print(f"Model downloaded to: {model_path}")
+
+            # Initialize the model architecture
+            model = FFTCNN()
+
+            # Load the saved weights (state_dict) into the model
+            model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
+
+            # Move to the target device and set evaluation mode
+            model.to(self.device)
+            model.eval()
+
+            return model
+        except Exception as e:
+            print(f"Error downloading or loading model from Hugging Face: {e}")
+            raise
+
+# --- Global Model Instance ---
+MODEL_REPO_ID = 'rhnsa/real_forged_classifier'
+MODEL_FILENAME = 'fft_cnn_model_78.pth'
+models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME)
+
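Note that load_state_dict() expects the checkpoint file to contain a bare state_dict; a sketch of the presumed save side on the training end (the filename mirrors the one above and is illustrative):

    import torch
    from model import FFTCNN

    # Presumed training-side counterpart: save only the state_dict,
    # which is what load_state_dict() above expects.
    trained = FFTCNN()  # stand-in for the fully trained network
    torch.save(trained.state_dict(), "fft_cnn_model_78.pth")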
features/real_forged_classifier/preprocessor.py
ADDED
@@ -0,0 +1,67 @@
+from PIL import Image
+import torch
+import numpy as np
+from typing import IO
+import cv2
+from torchvision import transforms
+
+# Import the globally loaded models instance
+from model_loader import models
+
+class ImagePreprocessor:
+    """
+    Handles preprocessing of images for the FFT CNN model.
+    """
+    def __init__(self):
+        """
+        Initializes the preprocessor.
+        """
+        self.device = models.device
+        # Define the image transformations, matching the training process
+        self.transform = transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+        ])
+
+    def process(self, image_file: IO) -> torch.Tensor:
+        """
+        Opens an image file, applies FFT, preprocesses it, and returns a tensor.
+
+        Args:
+            image_file (IO): The image file object (e.g., from a file upload).
+
+        Returns:
+            torch.Tensor: The preprocessed image as a tensor, ready for the model.
+        """
+        try:
+            # Read the image file into a numpy array
+            image_np = np.frombuffer(image_file.read(), np.uint8)
+            # Decode the image as grayscale
+            img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
+        except Exception as e:
+            print(f"Error reading or decoding image: {e}")
+            raise ValueError("Invalid or corrupted image file.")
+
+        if img is None:
+            raise ValueError("Could not decode image. File may be empty or corrupted.")
+
+        # 1. Apply Fast Fourier Transform (FFT)
+        f = np.fft.fft2(img)
+        fshift = np.fft.fftshift(f)
+        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)
+
+        # Normalize the magnitude spectrum to be in the range [0, 255]
+        magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
+        magnitude_spectrum = np.uint8(magnitude_spectrum)
+
+        # 2. Apply torchvision transforms
+        image_tensor = self.transform(magnitude_spectrum)
+
+        # Add a batch dimension and move to the correct device
+        image_tensor = image_tensor.unsqueeze(0).to(self.device)
+
+        return image_tensor
+
+# Create a single instance of the preprocessor
+preprocessor = ImagePreprocessor()
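The FFT stage converts a spatial grayscale image into a centered log-magnitude spectrum before the usual resize/to-tensor transforms; a self-contained sketch of just that transform on a synthetic image:

    import numpy as np
    import cv2

    img = np.random.randint(0, 256, (224, 224), dtype=np.uint8)  # synthetic grayscale image

    f = np.fft.fft2(img)
    fshift = np.fft.fftshift(f)                  # shift the DC component to the center
    magnitude = 20 * np.log(np.abs(fshift) + 1)  # +1 avoids log(0)
    magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    print(magnitude.shape, magnitude.dtype)      # (224, 224) uint8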
features/real_forged_classifier/routes.py
ADDED
@@ -0,0 +1,37 @@
+from fastapi import APIRouter, File, UploadFile, HTTPException, status
+from fastapi.responses import JSONResponse
+
+# Import the controller instance
+from controller import controller
+
+# Create an API router
+router = APIRouter()
+
+@router.post("/classify_forgery", summary="Classify an image as Real or Fake")
+async def classify_image_endpoint(image: UploadFile = File(...)):
+    """
+    Accepts an image file and classifies it as 'real' or 'fake'.
+
+    - **image**: The image file to be classified (e.g., JPEG, PNG).
+
+    Returns a JSON object with the classification and a confidence score.
+    """
+    # Check for a valid image content type
+    if not image.content_type.startswith("image/"):
+        raise HTTPException(
+            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
+            detail="Unsupported file type. Please upload an image (e.g., JPEG, PNG)."
+        )
+
+    # The controller expects a file-like object, which `image.file` provides
+    result = controller.classify_image(image.file)
+
+    if "error" in result:
+        # If the controller returned an error, forward it as an HTTP exception
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=result["error"]
+        )
+
+    return JSONResponse(content=result, status_code=status.HTTP_200_OK)
+
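With the server running (uvicorn main:app), the endpoint can be exercised from any HTTP client; a hedged example using requests, where the URL, port, and filename are hypothetical:

    import requests

    # Hypothetical client call; assumes the API is served locally on port 8000.
    with open("sample.jpg", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:8000/api/classify_forgery",
            files={"image": ("sample.jpg", f, "image/jpeg")},
        )
    print(resp.status_code, resp.json())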