# ritz26's picture
# Update app.py
# ed0aacc verified
from flask import Flask, render_template, request
from PIL import Image
import os
import torch
import cv2
import mediapipe as mp
from transformers import SamModel, SamProcessor
from diffusers.utils import load_image
from torchvision import transforms
import tempfile
app = Flask(__name__)

# Use temporary directories for uploads and outputs
UPLOAD_FOLDER = '/tmp/uploads'
OUTPUT_FOLDER = '/tmp/outputs'

# Ensure folders exist (best effort: failures are reported, not fatal).
try:
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    # Also create static directories for serving files. They are anchored at
    # Flask's static folder (an absolute path under the app's root) rather than
    # the process CWD, so files land where Flask actually serves them even when
    # the app is started from a different working directory.
    os.makedirs(os.path.join(app.static_folder, 'uploads'), exist_ok=True)
    os.makedirs(os.path.join(app.static_folder, 'outputs'), exist_ok=True)
except OSError as e:
    # OSError (not just PermissionError) so e.g. a read-only filesystem is
    # reported the same way instead of crashing at import time.
    print(f"Could not create directories: {e}")
# Hugging Face id of the SAM checkpoint used for garment segmentation.
SAM_MODEL_ID = "Zigeng/SlimSAM-uniform-50"

# Load model once at startup. Loading is all-or-nothing: on any failure both
# names are bound to None, so request handlers can detect the missing model
# instead of failing later with a NameError.
model = None
processor = None
try:
    model = SamModel.from_pretrained(SAM_MODEL_ID)
    processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
    print("Models loaded successfully")
except Exception as e:
    model = None
    processor = None
    print(f"Error loading models: {e}")
# Pose function
def get_shoulder_coordinates(image_path):
    """Locate the left and right shoulders in an image with MediaPipe Pose.

    Args:
        image_path: Path of the image file to analyse (read with OpenCV).

    Returns:
        ``((lx, ly), (rx, ry))``: integer pixel coordinates of MediaPipe's
        LEFT_SHOULDER and RIGHT_SHOULDER landmarks, clamped to the image
        bounds. Returns ``None`` if the image cannot be read, no pose is
        detected, or any error occurs. Errors are printed, never raised.
    """
    try:
        # Read the image first so an unreadable file never pays for building
        # the (heavy, model_complexity=2) pose model.
        image = cv2.imread(image_path)
        if image is None:
            print(f"Could not load image from {image_path}")
            return None
        height, width = image.shape[:2]
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # MediaPipe wants RGB

        mp_pose = mp.solutions.pose

        def to_pixel(landmark):
            # Normalised landmarks can fall slightly outside [0, 1] for points
            # near/beyond the frame edge; clamp so the SAM prompt points that
            # are built from them stay inside the image.
            x = min(max(int(landmark.x * width), 0), width - 1)
            y = min(max(int(landmark.y * height), 0), height - 1)
            return x, y

        # The context manager closes the Pose graph's native resources; the
        # previous version leaked one Pose instance per call.
        with mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5,
        ) as pose:
            results = pose.process(image_rgb)
            if not results.pose_landmarks:
                print("No pose landmarks detected")
                return None
            landmarks = results.pose_landmarks.landmark
            left_shoulder = to_pixel(landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER])
            right_shoulder = to_pixel(landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER])

        print(f"Left shoulder: {left_shoulder}")
        print(f"Right shoulder: {right_shoulder}")
        return left_shoulder, right_shoulder
    except Exception as e:
        print(f"Error in pose detection: {e}")
        return None
def _composite_tshirt(person_path, tshirt_path, coordinates):
    """Segment the garment region with SAM and paste the t-shirt image into it.

    Args:
        person_path: Path of the saved person photo.
        tshirt_path: Path of the saved t-shirt image.
        coordinates: ``(left_shoulder, right_shoulder)`` pixel tuples as
            returned by ``get_shoulder_coordinates``.

    Returns:
        A PIL image the size of the person photo in which the masked region
        is replaced by the t-shirt image resized to the same size.
    """
    img = load_image(person_path)
    new_tshirt = load_image(tshirt_path)

    left_shoulder, right_shoulder = coordinates
    # One image, one prompt made of two points: both shoulders.
    input_points = [[[left_shoulder[0], left_shoulder[1]],
                     [right_shoulder[0], right_shoulder[1]]]]
    inputs = processor(img, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0][0] holds the masks for the first (only) prompt; index 2 is the
    # candidate this app has always used. NOTE(review): presumably SAM's
    # largest multimask candidate -- consider picking by outputs.iou_scores.
    mask_tensor = masks[0][0][2].to(dtype=torch.uint8)
    mask = transforms.ToPILImage()(mask_tensor * 255)

    new_tshirt = new_tshirt.resize(img.size, Image.LANCZOS)
    return Image.composite(new_tshirt, img, mask)


@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the try-on page.

    GET renders ``index.html``. POST expects two uploaded files,
    ``person_image`` and ``tshirt_image``. It locates the person's shoulders,
    segments the garment region with SAM and pastes the t-shirt into it.

    On success the page is re-rendered with ``result_img`` set to the
    result's path relative to the static folder. On failure a plain-text
    message is returned: exceptions are printed and reported, never raised.
    """
    if request.method != 'POST':
        return render_template('index.html')

    import uuid  # only needed here, to build per-request file names

    # Per-request file names: fixed names ('person.jpg', 'result.jpg', ...)
    # let concurrent requests overwrite each other's inputs and outputs, and
    # the constant result URL could be served stale from the browser cache.
    request_id = uuid.uuid4().hex
    person_path = os.path.join(UPLOAD_FOLDER, f'person_{request_id}.jpg')
    tshirt_path = os.path.join(UPLOAD_FOLDER, f'tshirt_{request_id}.png')
    try:
        person_file = request.files.get('person_image')
        tshirt_file = request.files.get('tshirt_image')
        if not person_file or not tshirt_file:
            return "Please upload both person and t-shirt images."
        # If the model globals were never bound this raises NameError, which
        # is reported by the handler below.
        if model is None or processor is None:
            return "Segmentation model is not loaded. Please try again later."

        person_file.save(person_path)
        tshirt_file.save(tshirt_path)

        coordinates = get_shoulder_coordinates(person_path)
        if coordinates is None:
            return "No pose detected. Please try with a different image where the person's shoulders are clearly visible."

        img_with_new_tshirt = _composite_tshirt(person_path, tshirt_path, coordinates)

        # Save the result to both the temp and static directories. The static
        # copy lives under app.static_folder (not the CWD), so Flask can serve
        # it. NOTE(review): static results are never deleted and will accumulate.
        result_name = f'result_{request_id}.jpg'
        static_outputs = os.path.join(app.static_folder, 'outputs')
        os.makedirs(static_outputs, exist_ok=True)
        img_with_new_tshirt.save(os.path.join(OUTPUT_FOLDER, result_name))
        img_with_new_tshirt.save(os.path.join(static_outputs, result_name))
        return render_template('index.html', result_img=f'outputs/{result_name}')
    except Exception as e:
        print(f"Error processing request: {e}")
        return f"Error processing images: {str(e)}"
    finally:
        # Best-effort removal of this request's uploads. A missing file (e.g.
        # validation failed before saving) is expected and ignored.
        for path in (person_path, tshirt_path):
            try:
                os.remove(path)
            except OSError:
                pass
if __name__ == '__main__':
    # Werkzeug's interactive debugger allows arbitrary code execution, so it
    # must not be on by default for a server bound to all interfaces. Opt in
    # with FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '0') == '1'
    # PORT overrides the default. Note that 6000 is on Chrome's list of
    # restricted ports, so Chrome may refuse to connect to it.
    port = int(os.environ.get('PORT', '6000'))
    app.run(debug=debug, host='0.0.0.0', port=port)