Spaces:
Runtime error
Runtime error
Commit
·
b84617e
1
Parent(s):
123ce5d
fixed app.py
Browse files- app.py +51 -42
- requirements.txt +4 -3
app.py
CHANGED
@@ -4,62 +4,71 @@ from PIL import Image
|
|
4 |
import gradio as gr
|
5 |
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
|
6 |
import gdown
|
|
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
11 |
|
12 |
-
#
|
13 |
-
# Device configuration
|
14 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
-
|
16 |
-
# Load model
|
17 |
-
MODEL_NAME = "nvidia/segformer-b0-finetuned-ade-512-512" # Thay bằng tên model của bạn
|
18 |
-
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
|
19 |
-
model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1) # Điều chỉnh output lớp classifier
|
20 |
-
model = model.to(device)
|
21 |
-
|
22 |
-
# Load checkpoint
|
23 |
-
checkpoint = torch.load(output, map_location=device)
|
24 |
-
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
25 |
-
model.eval()
|
26 |
-
|
27 |
-
# Image processor
|
28 |
-
image_processor = SegformerImageProcessor()
|
29 |
-
|
30 |
-
# Inference function
|
31 |
def predict_segmentation(image):
|
32 |
"""
|
33 |
Predict segmentation mask for input image.
|
34 |
"""
|
35 |
-
raw_image =
|
36 |
inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
|
37 |
-
H, W = raw_image.
|
38 |
|
39 |
with torch.no_grad():
|
40 |
outputs = model(**inputs)
|
41 |
logits = outputs.logits
|
42 |
-
upsampled_logits = torch.nn.functional.interpolate(logits, size=(H, W)
|
43 |
predictions = torch.argmax(upsampled_logits, dim=1).squeeze().cpu().numpy()
|
|
|
|
|
44 |
|
45 |
-
return predictions
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
""
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
#
|
56 |
-
|
57 |
-
fn=gradio_interface,
|
58 |
-
inputs=gr.Image(type="pil"), # Đầu vào: ảnh PIL
|
59 |
-
outputs="image", # Đầu ra: ảnh segmentation mask
|
60 |
-
title="Segmentation with Segformer",
|
61 |
-
description="Upload an image to generate a segmentation mask."
|
62 |
-
)
|
63 |
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
iface.launch()
|
|
|
4 |
import gradio as gr
|
5 |
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
|
6 |
import gdown
|
7 |
+
import os
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import cv2
|
10 |
|
11 |
+
def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
|
12 |
+
colored_mask = np.zeros_like(image)
|
13 |
+
colored_mask[mask > 0] = mask_color
|
14 |
+
overlay_image = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
|
15 |
+
return overlay_image
|
16 |
|
17 |
+
# Hàm dự đoán segmentation mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def predict_segmentation(image):
|
19 |
"""
|
20 |
Predict segmentation mask for input image.
|
21 |
"""
|
22 |
+
raw_image = np.array(image)
|
23 |
inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
|
24 |
+
H, W = raw_image.shape[0], raw_image.shape[1]
|
25 |
|
26 |
with torch.no_grad():
|
27 |
outputs = model(**inputs)
|
28 |
logits = outputs.logits
|
29 |
+
upsampled_logits = torch.nn.functional.interpolate(logits, size=(H, W))
|
30 |
predictions = torch.argmax(upsampled_logits, dim=1).squeeze().cpu().numpy()
|
31 |
+
overlay = overlay_mask(raw_image,predictions)
|
32 |
+
return overlay
|
33 |
|
|
|
34 |
|
35 |
+
if __name__ == '__main__':
|
36 |
+
# Tải file checkpoint nếu chưa tồn tại
|
37 |
+
url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570"
|
38 |
+
output = "Segformer_ISIC2018_epoch_50_model.pth"
|
39 |
+
if not os.path.exists(output):
|
40 |
+
gdown.download(url, output, quiet=False)
|
41 |
+
|
42 |
+
# Thiết lập thiết bị
|
43 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
44 |
+
|
45 |
+
# Load model từ HuggingFace
|
46 |
+
MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
|
47 |
+
try:
|
48 |
+
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
|
49 |
+
except EnvironmentError as e:
|
50 |
+
print(f"Lỗi khi tải model từ HuggingFace: {e}")
|
51 |
+
exit()
|
52 |
+
|
53 |
+
# Điều chỉnh và tải checkpoint
|
54 |
+
model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1)
|
55 |
+
model = model.to(device)
|
56 |
+
model = torch.nn.DataParallel(model)
|
57 |
+
|
58 |
+
# Load checkpoint
|
59 |
+
checkpoint = torch.load(output, map_location=device,weights_only=True)
|
60 |
+
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
61 |
+
model.eval()
|
62 |
|
63 |
+
# Image processor
|
64 |
+
image_processor = SegformerImageProcessor()
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
+
# Gradio app
|
67 |
+
iface = gr.Interface(
|
68 |
+
fn=predict_segmentation, # Gọi hàm dự đoán
|
69 |
+
inputs=gr.Image(type="pil"),
|
70 |
+
outputs="image",
|
71 |
+
title="Segmentation with Segformer",
|
72 |
+
description="Upload an image to generate a segmentation mask."
|
73 |
+
)
|
74 |
iface.launch()
|
requirements.txt
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
torch
|
2 |
-
torchvision
|
3 |
-
transformers
|
4 |
numpy
|
5 |
pillow
|
6 |
gradio
|
7 |
-
|
|
|
|
|
|
|
|
1 |
torch
|
|
|
|
|
2 |
numpy
|
3 |
pillow
|
4 |
gradio
|
5 |
+
transformers
|
6 |
+
gdown
|
7 |
+
matplotlib
|
8 |
+
opencv-python
|