|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import torch |
|
from torchvision import transforms |
|
from transformers import AutoModelForImageSegmentation |
|
from typing import Dict, List, Any |
|
from io import BytesIO |
|
import base64 |
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that removes the background from an image.

    Expects a payload of the form ``{"inputs": "<base64-encoded image>"}``
    and returns the same image as a base64-encoded PNG whose alpha channel
    is the predicted foreground mask.
    """

    def __init__(self, path: str = "."):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            path: Directory containing the model weights. The Hugging Face
                Inference Endpoints runtime passes the repository path here;
                defaults to the current directory for backward compatibility.
        """
        self.pipeline = AutoModelForImageSegmentation.from_pretrained(
            path, trust_remote_code=True
        )
        # Inference only: switch to eval mode once at load time instead of
        # on every request.
        self.pipeline.eval()
        # Trade a little float32 matmul precision for speed (TF32 on
        # supporting GPUs); no-op on CPU.
        torch.set_float32_matmul_precision("high")
        self.image_size = (1024, 1024)
        # The transform is stateless, so build it once here rather than
        # re-creating it per call.
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            # ImageNet mean/std normalization expected by the backbone.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run background removal on a single base64-encoded image.

        Args:
            data: Request payload; ``data["inputs"]`` must be a base64
                string containing the image bytes.

        Returns:
            A one-element list ``[{"image": <base64-encoded PNG>}]`` where
            the image is the original made transparent outside the
            predicted foreground.
        """
        image_b64 = data.get("inputs", "")
        image_data = base64.b64decode(image_b64)

        # Decode the payload once; keep the RGBA original for compositing
        # and feed a normalized RGB tensor to the model.
        original_image = Image.open(BytesIO(image_data)).convert("RGBA")
        input_images = self.transform_image(
            original_image.convert("RGB")
        ).unsqueeze(0)

        with torch.no_grad():
            # The model returns a list of side outputs; the last entry is
            # the final prediction. Sigmoid maps logits to a [0, 1] mask.
            preds = self.pipeline(input_images)[-1].sigmoid()
        mask = transforms.ToPILImage()(preds[0].squeeze())

        # Scale the mask back to the original resolution and install it as
        # the alpha channel.
        original_image.putalpha(mask.resize(original_image.size))

        buffered = BytesIO()
        original_image.save(buffered, format="PNG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return [{"image": base64_image}]