udman99 commited on
Commit
d36c36e
·
verified ·
1 Parent(s): 87fed53

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -7
handler.py CHANGED
@@ -7,12 +7,11 @@ from typing import Dict, List, Any
7
  from io import BytesIO
8
  import base64
9
 
10
- class EndpointHandler():
11
- def init(self, path=""):
12
  # Initialize the image segmentation pipeline
13
  self.pipeline = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
14
  torch.set_float32_matmul_precision(['high', 'highest'][0])
15
-
16
  self.image_size = (1024, 1024)
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -27,7 +26,7 @@ class EndpointHandler():
27
  # Extract the image path from the input data
28
  image_b64 = data.get("inputs", "")
29
  image_data = base64.b64decode(image_b64)
30
- image = Image.open(BytesIO(image_data)).convert("RGB")
31
  input_images = transform_image(image).unsqueeze(0)
32
  print(2)
33
  # Prediction
@@ -35,13 +34,16 @@ class EndpointHandler():
35
  preds = self.pipeline(input_images)[-1].sigmoid()
36
  pred = preds[0].squeeze()
37
  pred_pil = transforms.ToPILImage()(pred)
38
- mask = pred_pil.resize(image.size)
39
- image.putalpha(mask)
 
 
 
40
  print(3)
41
 
42
  # Convert the image with alpha mask to base64
43
  buffered = BytesIO()
44
- image.save(buffered, format="PNG")
45
  base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
46
 
47
  # Return the result as a list of dictionaries
 
7
  from io import BytesIO
8
  import base64
9
 
10
+ class EndpointHandler:
11
+ def __init__(self, path=""):
12
  # Initialize the image segmentation pipeline
13
  self.pipeline = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
14
  torch.set_float32_matmul_precision(['high', 'highest'][0])
 
15
  self.image_size = (1024, 1024)
16
 
17
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
26
  # Extract the image path from the input data
27
  image_b64 = data.get("inputs", "")
28
  image_data = base64.b64decode(image_b64)
29
+ image = Image.open(BytesIO(image_data)).convert("RGB") # Convert to RGB instead of RGBA
30
  input_images = transform_image(image).unsqueeze(0)
31
  print(2)
32
  # Prediction
 
34
  preds = self.pipeline(input_images)[-1].sigmoid()
35
  pred = preds[0].squeeze()
36
  pred_pil = transforms.ToPILImage()(pred)
37
+
38
+ # Resize the mask to the original image size
39
+ original_image = Image.open(BytesIO(image_data)).convert("RGBA") # Load original RGBA image for alpha
40
+ mask = pred_pil.resize(original_image.size)
41
+ original_image.putalpha(mask)
42
  print(3)
43
 
44
  # Convert the image with alpha mask to base64
45
  buffered = BytesIO()
46
+ original_image.save(buffered, format="PNG")
47
  base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
48
 
49
  # Return the result as a list of dictionaries