NhatNam214 commited on
Commit
b84617e
·
1 Parent(s): 123ce5d

fixed app.py

Browse files
Files changed (2) hide show
  1. app.py +51 -42
  2. requirements.txt +4 -3
app.py CHANGED
@@ -4,62 +4,71 @@ from PIL import Image
4
  import gradio as gr
5
  from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
6
  import gdown
 
 
 
7
 
8
- url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW"
9
- output = "Segformer_ISIC2018_epoch_50_model.pth"
10
- gdown.download(url, output, quiet=False)
 
 
11
 
12
- # Load checkpoint
13
- # Device configuration
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- # Load model
17
- MODEL_NAME = "nvidia/segformer-b0-finetuned-ade-512-512" # Thay bằng tên model của bạn
18
- model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
19
- model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1) # Điều chỉnh output lớp classifier
20
- model = model.to(device)
21
-
22
- # Load checkpoint
23
- checkpoint = torch.load(output, map_location=device)
24
- model.load_state_dict(checkpoint['model_state_dict'], strict=False)
25
- model.eval()
26
-
27
- # Image processor
28
- image_processor = SegformerImageProcessor()
29
-
30
- # Inference function
31
  def predict_segmentation(image):
32
  """
33
  Predict segmentation mask for input image.
34
  """
35
- raw_image = Image.fromarray(np.array(image)) # Đảm bảo ảnh là định dạng PIL
36
  inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
37
- H, W = raw_image.size[1], raw_image.size[0]
38
 
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
  logits = outputs.logits
42
- upsampled_logits = torch.nn.functional.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
43
  predictions = torch.argmax(upsampled_logits, dim=1).squeeze().cpu().numpy()
 
 
44
 
45
- return predictions
46
 
47
- # Gradio interface
48
- def gradio_interface(image):
49
- """
50
- Gradio input-output interface for model prediction.
51
- """
52
- segmentation_mask = predict_segmentation(image)
53
- return Image.fromarray((segmentation_mask * 255).astype(np.uint8)) # Trả về mask dưới dạng ảnh nhị phân
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # Gradio app
56
- iface = gr.Interface(
57
- fn=gradio_interface,
58
- inputs=gr.Image(type="pil"), # Đầu vào: ảnh PIL
59
- outputs="image", # Đầu ra: ảnh segmentation mask
60
- title="Segmentation with Segformer",
61
- description="Upload an image to generate a segmentation mask."
62
- )
63
 
64
- if __name__ == "__main__":
 
 
 
 
 
 
 
65
  iface.launch()
 
4
  import gradio as gr
5
  from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
6
  import gdown
7
+ import os
8
+ import matplotlib.pyplot as plt
9
+ import cv2
10
 
11
+ def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
12
+ colored_mask = np.zeros_like(image)
13
+ colored_mask[mask > 0] = mask_color
14
+ overlay_image = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
15
+ return overlay_image
16
 
17
+ # Hàm dự đoán segmentation mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def predict_segmentation(image):
19
  """
20
  Predict segmentation mask for input image.
21
  """
22
+ raw_image = np.array(image)
23
  inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
24
+ H, W = raw_image.shape[0], raw_image.shape[1]
25
 
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  logits = outputs.logits
29
+ upsampled_logits = torch.nn.functional.interpolate(logits, size=(H, W))
30
  predictions = torch.argmax(upsampled_logits, dim=1).squeeze().cpu().numpy()
31
+ overlay = overlay_mask(raw_image,predictions)
32
+ return overlay
33
 
 
34
 
35
+ if __name__ == '__main__':
36
+ # Tải file checkpoint nếu chưa tồn tại
37
+ url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570"
38
+ output = "Segformer_ISIC2018_epoch_50_model.pth"
39
+ if not os.path.exists(output):
40
+ gdown.download(url, output, quiet=False)
41
+
42
+ # Thiết lập thiết bị
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+ # Load model từ HuggingFace
46
+ MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
47
+ try:
48
+ model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
49
+ except EnvironmentError as e:
50
+ print(f"Lỗi khi tải model từ HuggingFace: {e}")
51
+ exit()
52
+
53
+ # Điều chỉnh và tải checkpoint
54
+ model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1)
55
+ model = model.to(device)
56
+ model = torch.nn.DataParallel(model)
57
+
58
+ # Load checkpoint
59
+ checkpoint = torch.load(output, map_location=device,weights_only=True)
60
+ model.load_state_dict(checkpoint['model_state_dict'], strict=False)
61
+ model.eval()
62
 
63
+ # Image processor
64
+ image_processor = SegformerImageProcessor()
 
 
 
 
 
 
65
 
66
+ # Gradio app
67
+ iface = gr.Interface(
68
+ fn=predict_segmentation, # Gọi hàm dự đoán
69
+ inputs=gr.Image(type="pil"),
70
+ outputs="image",
71
+ title="Segmentation with Segformer",
72
+ description="Upload an image to generate a segmentation mask."
73
+ )
74
  iface.launch()
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
  torch
2
- torchvision
3
- transformers
4
  numpy
5
  pillow
6
  gradio
7
- gdown
 
 
 
 
1
  torch
 
 
2
  numpy
3
  pillow
4
  gradio
5
+ transformers
6
+ gdown
7
+ matplotlib
8
+ opencv-python