Rohan Kumar Shah commited on
Commit
b7c5baf
·
1 Parent(s): bfb2e8a

added real and forgery detection model

Browse files
features/real_forged_classifier/controller.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import IO
2
+ from preprocessor import preprocessor
3
+ from inferencer import interferencer
4
+
5
+ class ClassificationController:
6
+ """
7
+ Controller to handle the image classification logic.
8
+ """
9
+ def classify_image(self, image_file: IO) -> dict:
10
+ """
11
+ Orchestrates the classification of a single image file.
12
+
13
+ Args:
14
+ image_file (IO): The image file to classify.
15
+
16
+ Returns:
17
+ dict: The classification result.
18
+ """
19
+ try:
20
+ # Step 1: Preprocess the image
21
+ image_tensor = preprocessor.process(image_file)
22
+
23
+ # Step 2: Perform inference
24
+ result = interferencer.predict(image_tensor)
25
+
26
+ return result
27
+ except ValueError as e:
28
+ # Handle specific errors like invalid images
29
+ return {"error": str(e)}
30
+ except Exception as e:
31
+ # Handle unexpected errors
32
+ print(f"An unexpected error occurred: {e}")
33
+ return {"error": "An internal error occurred during classification."}
34
+
35
+ # Create a single instance of the controller
36
+ controller = ClassificationController()
features/real_forged_classifier/inferencer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+ # Import the globally loaded models instance
6
+ from model_loader import models
7
+
8
+ class Interferencer:
9
+ """
10
+ Performs inference using the FFT CNN model.
11
+ """
12
+ def __init__(self):
13
+ """
14
+ Initializes the interferencer with the loaded model.
15
+ """
16
+ self.fft_model = models.fft_model
17
+
18
+ @torch.no_grad()
19
+ def predict(self, image_tensor: torch.Tensor) -> dict:
20
+ """
21
+ Takes a preprocessed image tensor and returns the classification result.
22
+
23
+ Args:
24
+ image_tensor (torch.Tensor): The preprocessed image tensor.
25
+
26
+ Returns:
27
+ dict: A dictionary containing the classification label and confidence score.
28
+ """
29
+ # 1. Get model outputs (logits)
30
+ outputs = self.fft_model(image_tensor)
31
+
32
+ # 2. Apply softmax to get probabilities
33
+ probabilities = F.softmax(outputs, dim=1)
34
+
35
+ # 3. Get the confidence and the predicted class index
36
+ confidence, predicted_idx = torch.max(probabilities, 1)
37
+
38
+ prediction = predicted_idx.item()
39
+
40
+ # 4. Map the prediction to a human-readable label
41
+ # Ensure this mapping matches the labels used during training
42
+ # Typically: 0 -> fake, 1 -> real
43
+ label_map = {0: 'fake', 1: 'real'}
44
+ classification_label = label_map.get(prediction, "unknown")
45
+
46
+ return {
47
+ "classification": classification_label,
48
+ "confidence": confidence.item()
49
+ }
50
+
51
+ # Create a single instance of the interferencer
52
+ interferencer = Interferencer()
features/real_forged_classifier/main.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from routes import router as api_router
3
+
4
+ # Initialize the FastAPI app
5
+ app = FastAPI(
6
+ title="Real vs. Fake Image Classification API",
7
+ description="An API to classify images as real or forged using FFT and cnn.",
8
+ version="1.0.0"
9
+ )
10
+
11
+ # Include the API router
12
+ # All routes defined in routes.py will be available under the /api prefix
13
+ app.include_router(api_router, prefix="/api", tags=["Classification"])
14
+
15
+ @app.get("/", tags=["Root"])
16
+ async def read_root():
17
+ """
18
+ A simple root endpoint to confirm the API is running.
19
+ """
20
+ return {"message": "Welcome to the Image Classification API. Go to /docs for the API documentation."}
21
+
22
+ # To run this application:
23
+ # 1. Make sure you have all dependencies from requirements.txt installed.
24
+ # 2. Make sure the 'svm_model.joblib' file is in the same directory.
25
+ # 3. Run the following command in your terminal:
26
+ # uvicorn main:app --reload
features/real_forged_classifier/model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class FFTCNN(nn.Module):
6
+ """
7
+ Defines the Convolutional Neural Network architecture.
8
+ This structure must match the model that was trained and saved.
9
+ """
10
+ def __init__(self):
11
+ super(FFTCNN, self).__init__()
12
+ # Ensure 'self.' is used here to define the layers as instance attributes
13
+ self.conv_layers = nn.Sequential(
14
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
15
+ nn.ReLU(),
16
+ nn.MaxPool2d(kernel_size=2, stride=2),
17
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(kernel_size=2, stride=2)
20
+ )
21
+
22
+ # Ensure 'self.' is used here as well
23
+ self.fc_layers = nn.Sequential(
24
+ nn.Linear(32 * 56 * 56, 128), # This size depends on your 224x224 input
25
+ nn.ReLU(),
26
+ nn.Linear(128, 2) # 2 output classes
27
+ )
28
+
29
+ def forward(self, x):
30
+ # Now, 'self.conv_layers' can be found because it was defined correctly
31
+ x = self.conv_layers(x)
32
+ x = x.view(x.size(0), -1) # Flatten the feature maps
33
+ x = self.fc_layers(x)
34
+ return x
features/real_forged_classifier/model_loader.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from huggingface_hub import hf_hub_download
4
+ from model import FFTCNN # Import the model architecture
5
+
6
+ class ModelLoader:
7
+ """
8
+ A class to load and hold the PyTorch CNN model.
9
+ """
10
+ def __init__(self, model_repo_id: str, model_filename: str):
11
+ """
12
+ Initializes the ModelLoader and loads the model.
13
+
14
+ Args:
15
+ model_repo_id (str): The repository ID on Hugging Face.
16
+ model_filename (str): The name of the model file (.pth) in the repository.
17
+ """
18
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ print(f"Using device: {self.device}")
20
+
21
+ self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
22
+ print("FFT CNN model loaded successfully.")
23
+
24
+ def _load_fft_model(self, repo_id: str, filename: str):
25
+ """
26
+ Downloads and loads the FFT CNN model from a Hugging Face Hub repository.
27
+
28
+ Args:
29
+ repo_id (str): The repository ID on Hugging Face.
30
+ filename (str): The name of the model file (.pth) in the repository.
31
+
32
+ Returns:
33
+ The loaded PyTorch model object.
34
+ """
35
+ print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
36
+ try:
37
+ # Download the model file from the Hub. It returns the cached path.
38
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
39
+ print(f"Model downloaded to: {model_path}")
40
+
41
+ # Initialize the model architecture
42
+ model = FFTCNN()
43
+
44
+ # Load the saved weights (state_dict) into the model
45
+ model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))
46
+
47
+ # Set the model to evaluation mode
48
+ model.to(self.device)
49
+ model.eval()
50
+
51
+ return model
52
+ except Exception as e:
53
+ print(f"Error downloading or loading model from Hugging Face: {e}")
54
+ raise
55
+
56
+ # --- Global Model Instance ---
57
+ MODEL_REPO_ID = 'rhnsa/real_forged_classifier'
58
+ MODEL_FILENAME = 'fft_cnn_model_78.pth'
59
+ models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME)
60
+
features/real_forged_classifier/preprocessor.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import numpy as np
4
+ from typing import IO
5
+ import cv2
6
+ from torchvision import transforms
7
+
8
+ # Import the globally loaded models instance
9
+ from model_loader import models
10
+
11
+ class ImagePreprocessor:
12
+ """
13
+ Handles preprocessing of images for the FFT CNN model.
14
+ """
15
+ def __init__(self):
16
+ """
17
+ Initializes the preprocessor.
18
+ """
19
+ self.device = models.device
20
+ # Define the image transformations, matching the training process
21
+ self.transform = transforms.Compose([
22
+ transforms.ToPILImage(),
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ ])
26
+
27
+ def process(self, image_file: IO) -> torch.Tensor:
28
+ """
29
+ Opens an image file, applies FFT, preprocesses it, and returns a tensor.
30
+
31
+ Args:
32
+ image_file (IO): The image file object (e.g., from a file upload).
33
+
34
+ Returns:
35
+ torch.Tensor: The preprocessed image as a tensor, ready for the model.
36
+ """
37
+ try:
38
+ # Read the image file into a numpy array
39
+ image_np = np.frombuffer(image_file.read(), np.uint8)
40
+ # Decode the image as grayscale
41
+ img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
42
+ except Exception as e:
43
+ print(f"Error reading or decoding image: {e}")
44
+ raise ValueError("Invalid or corrupted image file.")
45
+
46
+ if img is None:
47
+ raise ValueError("Could not decode image. File may be empty or corrupted.")
48
+
49
+ # 1. Apply Fast Fourier Transform (FFT)
50
+ f = np.fft.fft2(img)
51
+ fshift = np.fft.fftshift(f)
52
+ magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1) # Add 1 to avoid log(0)
53
+
54
+ # Normalize the magnitude spectrum to be in the range [0, 255]
55
+ magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
56
+ magnitude_spectrum = np.uint8(magnitude_spectrum)
57
+
58
+ # 2. Apply torchvision transforms
59
+ image_tensor = self.transform(magnitude_spectrum)
60
+
61
+ # Add a batch dimension and move to the correct device
62
+ image_tensor = image_tensor.unsqueeze(0).to(self.device)
63
+
64
+ return image_tensor
65
+
66
+ # Create a single instance of the preprocessor
67
+ preprocessor = ImagePreprocessor()
features/real_forged_classifier/routes.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, File, UploadFile, HTTPException, status
2
+ from fastapi.responses import JSONResponse
3
+
4
+ # Import the controller instance
5
+ from controller import controller
6
+
7
+ # Create an API router
8
+ router = APIRouter()
9
+
10
+ @router.post("/classify_forgery", summary="Classify an image as Real or Fake")
11
+ async def classify_image_endpoint(image: UploadFile = File(...)):
12
+ """
13
+ Accepts an image file and classifies it as 'real' or 'fake'.
14
+
15
+ - **image**: The image file to be classified (e.g., JPEG, PNG).
16
+
17
+ Returns a JSON object with the classification and a confidence score.
18
+ """
19
+ # Check for a valid image content type
20
+ if not image.content_type.startswith("image/"):
21
+ raise HTTPException(
22
+ status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
23
+ detail="Unsupported file type. Please upload an image (e.g., JPEG, PNG)."
24
+ )
25
+
26
+ # The controller expects a file-like object, which `image.file` provides
27
+ result = controller.classify_image(image.file)
28
+
29
+ if "error" in result:
30
+ # If the controller returned an error, forward it as an HTTP exception
31
+ raise HTTPException(
32
+ status_code=status.HTTP_400_BAD_REQUEST,
33
+ detail=result["error"]
34
+ )
35
+
36
+ return JSONResponse(content=result, status_code=status.HTTP_200_OK)
37
+