# briaa-2.0-aina-bg-rmv / handler.py
# Last change: "Update handler.py" by udman99 (commit 12f2dcd, verified).
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Dict, List, Any
from io import BytesIO
import base64
class EndpointHandler:
    """Inference Endpoints handler that removes the background of an image.

    The segmentation model predicts a foreground mask. The handler returns the
    request image, at its original size, as a base64 PNG whose alpha channel
    is that mask.
    """

    def __init__(self, path: str = "."):
        """Load the segmentation model and build the preprocessing transform.

        Args:
            path: Directory holding the model weights/code. The Inference
                Endpoints toolkit passes the repository directory here. The
                default ``"."`` matches the previously hard-coded location,
                and an empty string also falls back to it.
        """
        # trust_remote_code=True lets transformers execute model code stored
        # alongside the weights.
        self.pipeline = AutoModelForImageSegmentation.from_pretrained(
            path or ".", trust_remote_code=True
        )
        # Let float32 matmuls use faster reduced-precision internals (e.g.
        # TF32) where the backend supports it.
        torch.set_float32_matmul_precision("high")
        # Use a GPU when one is present; otherwise stay on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipeline.to(self.device)
        # Inference only: switch to eval mode once, not on every request.
        self.pipeline.eval()

        self.image_size = (1024, 1024)
        # Resize to the model input size, then normalise with the standard
        # ImageNet mean/std. Built once and reused across requests.
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Remove the background from one image.

        Args:
            data: Request payload. ``data["inputs"]`` is either a
                base64-encoded image (``str`` or ``bytes``) or an
                already-decoded ``PIL.Image.Image``.

        Returns:
            ``[{"image": <base64 PNG string>}]``: the input image in RGBA at
            its original size, with the predicted mask as its alpha channel.

        Raises:
            ValueError: If ``inputs`` is missing or empty, or is not valid
                base64 (``binascii.Error`` is a ``ValueError`` subclass).
            PIL.UnidentifiedImageError: If the decoded bytes are not an image.
        """
        source = self._load_image(data.get("inputs", ""))
        mask = self._predict_mask(source.convert("RGB"))
        # convert() always returns a new image, so putalpha never mutates
        # a caller-supplied PIL image.
        output = source.convert("RGBA")
        # The mask comes out at model resolution; stretch it back to the
        # original image size before installing it as alpha.
        output.putalpha(mask.resize(output.size))
        return [{"image": self._encode_png(output)}]

    @staticmethod
    def _load_image(payload: Any) -> "Image.Image":
        """Return the request image, decoding it from base64 when needed."""
        if isinstance(payload, Image.Image):
            return payload
        if not payload:
            raise ValueError(
                "Request field 'inputs' must contain a base64-encoded image."
            )
        # Decoded once; both the RGB model input and the RGBA output are
        # derived from this single image.
        return Image.open(BytesIO(base64.b64decode(payload)))

    def _predict_mask(self, rgb_image: "Image.Image") -> "Image.Image":
        """Run the model on *rgb_image* and return a single-channel mask image."""
        batch = self.transform_image(rgb_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Take the last model output (presumably the final-stage
            # prediction — confirm against the model code) and squash it
            # into [0, 1].
            preds = self.pipeline(batch)[-1].sigmoid()
        # preds[0] drops the batch dimension, and squeeze() drops the
        # remaining size-1 dimensions. ToPILImage needs a CPU tensor.
        return transforms.ToPILImage()(preds[0].squeeze().cpu())

    @staticmethod
    def _encode_png(image: "Image.Image") -> str:
        """Serialise *image* as PNG and return it as base64 text."""
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")