---
license: apache-2.0
datasets:
- SilpaCS/Augmented_alzheimer
language:
- en
base_model:
- google/siglip2-base-patch16-224
pipeline_tag: image-classification
library_name: transformers
tags:
- Alzheimer
- Stage-Classifier
- SigLIP2
---

# **Alzheimer-Stage-Classifier**
> **Alzheimer-Stage-Classifier** is a multi-class image classification model based on `google/siglip2-base-patch16-224`, designed to identify stages of Alzheimer’s disease from medical imaging data. This tool can assist in **clinical decision support**, **early diagnosis**, and **disease progression tracking**.
```
Classification Report:
                  precision    recall  f1-score   support

    MildDemented     0.9634    0.9860    0.9746      8960
ModerateDemented     1.0000    1.0000    1.0000      6464
     NonDemented     0.8920    0.8910    0.8915      9600
VeryMildDemented     0.8904    0.8704    0.8803      8960

        accuracy                         0.9314     33984
       macro avg     0.9364    0.9369    0.9366     33984
    weighted avg     0.9309    0.9314    0.9311     33984
```
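The `macro avg` and `weighted avg` rows appear to follow the standard definitions (as in scikit-learn's `classification_report`): the macro average is the unweighted mean of the four per-class scores, and the weighted average weights each class by its support. The short sketch below recomputes both from the per-class rows of the report. Because it starts from the rounded four-decimal values, it agrees with the reported averages only to about three decimal places.

```python
# Per-class rows copied from the classification report: (precision, recall, f1, support)
report = {
    "MildDemented": (0.9634, 0.9860, 0.9746, 8960),
    "ModerateDemented": (1.0000, 1.0000, 1.0000, 6464),
    "NonDemented": (0.8920, 0.8910, 0.8915, 9600),
    "VeryMildDemented": (0.8904, 0.8704, 0.8803, 8960),
}

# Total number of evaluated images across all classes
total_support = sum(row[3] for row in report.values())


def macro_avg(col):
    """Unweighted mean of one metric column across classes."""
    return sum(row[col] for row in report.values()) / len(report)


def weighted_avg(col):
    """Support-weighted mean of one metric column across classes."""
    return sum(row[col] * row[3] for row in report.values()) / total_support


for name, col in [("precision", 0), ("recall", 1), ("f1-score", 2)]:
    print(f"{name:>9}  macro={macro_avg(col):.4f}  weighted={weighted_avg(col):.4f}")

print("total support:", total_support)  # 33984
```

The weighted recall coincides with the overall accuracy (0.9314): summing recall × support over the classes counts every correctly classified image exactly once, and dividing by the total support gives the fraction classified correctly.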

---
## **Label Classes**
The model classifies each input image into one of the following four classes, listed as `index: label`:
```
0: MildDemented
1: ModerateDemented
2: NonDemented
3: VeryMildDemented
```
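As a minimal sketch of how this index table is used, assuming the model emits one logit per class in the order listed above: softmax turns the logits into probabilities, and the predicted stage is the label of the highest-probability index. The logits here are made up for illustration, and `decode_prediction` is a hypothetical helper, not part of the model's or the library's API.

```python
import math

# Label table from this model card (class index -> stage name)
ID2LABEL = {0: "MildDemented", 1: "ModerateDemented", 2: "NonDemented", 3: "VeryMildDemented"}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_prediction(logits):
    """Hypothetical helper: return (stage name, probability) for the top class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs[best]


# Made-up logits for illustration only (not real model output)
label, p = decode_prediction([0.2, -1.5, 3.1, 1.0])
print(label, round(p, 3))
```

The Example Inference Code section performs the same softmax with `torch.nn.functional.softmax`, but returns the probability of every class rather than only the top one.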
---
## **Installation**
```bash
pip install transformers torch pillow gradio
```
---
## **Example Inference Code**
```python
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch

# Load model and processor
model_name = "prithivMLmods/Alzheimer-Stage-Classifier"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# ID to label mapping
id2label = {
    "0": "MildDemented",
    "1": "ModerateDemented",
    "2": "NonDemented",
    "3": "VeryMildDemented"
}

def classify_alzheimer_stage(image):
    # Gradio passes a NumPy array; convert it to a 3-channel PIL image
    image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # Softmax over the class dimension; squeeze() drops the batch dimension of 1
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()

    # Map each class index to its label and a rounded probability
    prediction = {id2label[str(i)]: round(probs[i], 3) for i in range(len(probs))}
    return prediction

# Gradio interface: image upload in, probabilities for all four classes out
iface = gr.Interface(
    fn=classify_alzheimer_stage,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=4, label="Alzheimer Stage"),
    title="Alzheimer-Stage-Classifier",
    description="Upload a brain scan image to classify the stage of Alzheimer's: NonDemented, VeryMildDemented, MildDemented, or ModerateDemented."
)

if __name__ == "__main__":
    iface.launch()
```
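`classify_alzheimer_stage` converts every input to RGB before preprocessing. This matters when a scan arrives as a single-channel grayscale array: `convert("RGB")` copies the intensity values into three identical channels, presumably matching the three-channel input the SigLIP2 backbone was built for. The sketch below assumes NumPy is available (Gradio depends on it) and uses a random array of arbitrary size as a stand-in for a real slice; resizing and normalization are left to the processor.

```python
import numpy as np
from PIL import Image

# Stand-in for a grayscale scan: random 8-bit values, arbitrary 176 x 208 size
rng = np.random.default_rng(0)
gray_slice = rng.integers(0, 256, size=(176, 208), dtype=np.uint8)

# A 2-D uint8 array becomes a single-channel PIL image (mode "L")
gray_image = Image.fromarray(gray_slice)
print(gray_image.mode)  # L

# Same conversion as classify_alzheimer_stage: replicate the channel three times
rgb_array = np.asarray(gray_image.convert("RGB"))
print(rgb_array.shape)  # (176, 208, 3)

# Every RGB channel holds the original intensities unchanged
print(all(np.array_equal(rgb_array[..., c], gray_slice) for c in range(3)))  # True
```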
---
## **Applications**
* **Early Alzheimer’s Screening**
* **Clinical Diagnosis Support**
* **Longitudinal Study & Disease Monitoring**
* **Research on Cognitive Decline**