import os
import cv2
import numpy as np
import torch
import gradio as gr
import segmentation_models_pytorch as smp
from PIL import Image
import boto3
import uuid
import io
from glob import glob

from pipeline.ImgOutlier import detect_outliers
from pipeline.normalization import align_images
# Detect whether we are running inside a Hugging Face Space
HF_SPACE = os.environ.get('SPACE_ID') is not None
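# Hugging Face Spaces set the SPACE_ID environment variable, so its presence is
# used here as a lightweight signal that the app is running in the hosted environment.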
# DigitalOcean Spaces upload function
def upload_mask(image, prefix="mask"):
    """
    Upload a segmentation mask image to DigitalOcean Spaces.

    Args:
        image: PIL Image object
        prefix: filename prefix

    Returns:
        URL of the uploaded file
    """
    try:
        # Read credentials from environment variables
        do_key = os.environ.get('DO_SPACES_KEY')
        do_secret = os.environ.get('DO_SPACES_SECRET')
        do_region = os.environ.get('DO_SPACES_REGION')
        do_bucket = os.environ.get('DO_SPACES_BUCKET')

        # Verify that all credentials are present
        if not all([do_key, do_secret, do_region, do_bucket]):
            return "DigitalOcean credentials are not set"

        # Create an S3-compatible client pointed at the Spaces endpoint
        session = boto3.session.Session()
        client = session.client('s3',
                                region_name=do_region,
                                endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
                                aws_access_key_id=do_key,
                                aws_secret_access_key=do_secret)

        # Generate a unique filename
        filename = f"{prefix}_{uuid.uuid4().hex}.png"

        # Convert the image to an in-memory byte stream
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        # Upload to Spaces
        client.upload_fileobj(
            img_byte_arr,
            do_bucket,
            filename,
            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
        )

        # Return the public URL
        url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
        return url
    except Exception as e:
        print(f"Upload failed: {str(e)}")
        return f"Upload error: {str(e)}"
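# Example usage of upload_mask (assumes the four DO_SPACES_* variables are set
# and the bucket accepts public-read ACLs):
#   url = upload_mask(Image.open("mask.png"), prefix="demo")
#   # -> https://<bucket>.<region>.digitaloceanspaces.com/demo_<hex>.png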
# Global configuration
MODEL_PATHS = {
    "Metal Marcy": "models/MM_best_model.pth",
    "Silhouette Jaenette": "models/SJ_best_model.pth"
}

REFERENCE_VECTOR_PATHS = {
    "Metal Marcy": "models/MM_mean.npy",
    "Silhouette Jaenette": "models/SJ_mean.npy"
}

REFERENCE_IMAGE_DIRS = {
    "Metal Marcy": "reference_images/MM",
    "Silhouette Jaenette": "reference_images/SJ"
}

# Category names and color mapping
CLASSES = ['background', 'cobbles', 'drysand', 'plant', 'sky', 'water', 'wetsand']
COLORS = [
    [0, 0, 0],        # background - black
    [139, 137, 137],  # cobbles - dark gray
    [255, 228, 181],  # drysand - light yellow
    [0, 128, 0],      # plant - green
    [135, 206, 235],  # sky - sky blue
    [0, 0, 255],      # water - blue
    [194, 178, 128]   # wetsand - sand brown
]
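# COLORS holds RGB triples indexed in the same order as CLASSES, so COLORS[i]
# is the display color used for CLASSES[i] when the mask is colorized.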
# Load model function
def load_model(model_path, device="cuda"):
    try:
        # Default to CPU in the Hugging Face Space (it may have no GPU)
        if HF_SPACE:
            device = "cpu"
        elif not torch.cuda.is_available():
            device = "cpu"  # the local environment may also lack a GPU

        model = smp.create_model(
            "DeepLabV3Plus",
            encoder_name="efficientnet-b6",
            in_channels=3,
            classes=len(CLASSES),
            encoder_weights=None
        )

        state_dict = torch.load(model_path, map_location=device)
        # Strip a leading 'model.' prefix (e.g. from a training wrapper) so the
        # weights load into the bare smp model
        if all(k.startswith('model.') for k in state_dict.keys()):
            state_dict = {k[6:]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully: {model_path}")
        return model
    except Exception as e:
        print(f"Failed to load model: {e}")
        return None
# Load reference vector
def load_reference_vector(vector_path):
    try:
        if not os.path.exists(vector_path):
            print(f"Reference vector file does not exist: {vector_path}")
            return []
        ref_vector = np.load(vector_path)
        print(f"Reference vector loaded successfully: {vector_path}")
        return ref_vector
    except Exception as e:
        print(f"Failed to load reference vector {vector_path}: {e}")
        return []
# Load reference images
def load_reference_images(ref_dir):
    try:
        if not os.path.exists(ref_dir):
            print(f"Reference image directory does not exist: {ref_dir}")
            os.makedirs(ref_dir, exist_ok=True)
            return []
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob(os.path.join(ref_dir, ext)))
        image_files.sort()
        reference_images = []
        # Use at most the first four reference images
        for file in image_files[:4]:
            img = cv2.imread(file)
            if img is not None:
                reference_images.append(img)
        print(f"Loaded {len(reference_images)} images from {ref_dir}")
        return reference_images
    except Exception as e:
        print(f"Failed to load images from {ref_dir}: {e}")
        return []
# Preprocess the image
def preprocess_image(image):
    # Drop the alpha channel if present
    if image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    orig_h, orig_w = image.shape[:2]
    # Resize to the model input size and apply ImageNet mean/std normalization
    image_resized = cv2.resize(image, (1024, 1024))
    image_norm = image_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_norm = (image_norm - mean) / std
    # HWC -> CHW and add a batch dimension
    image_tensor = torch.from_numpy(image_norm.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, orig_h, orig_w
# Generate segmentation map and visualization
def generate_segmentation_map(prediction, orig_h, orig_w):
    mask = prediction.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    mask_resized = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    # Dilate each foreground class into neighbouring background pixels to close
    # small unlabeled gaps along class boundaries
    kernel = np.ones((5, 5), np.uint8)
    processed_mask = mask_resized.copy()
    for idx in range(1, len(CLASSES)):
        class_mask = (mask_resized == idx).astype(np.uint8)
        dilated_mask = cv2.dilate(class_mask, kernel, iterations=2)
        dilated_effect = dilated_mask & (mask_resized == 0)
        processed_mask[dilated_effect > 0] = idx
    # Colorize the processed mask using the class color table (RGB)
    segmentation_map = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for idx, color in enumerate(COLORS):
        segmentation_map[processed_mask == idx] = color
    return segmentation_map
# Analysis result HTML
def create_analysis_result(mask):
    total_pixels = mask.size
    percentages = {cls: round((np.sum(mask == i) / total_pixels) * 100, 1)
                   for i, cls in enumerate(CLASSES)}
    ordered = ['sky', 'cobbles', 'plant', 'drysand', 'wetsand', 'water']
    result = "<div style='font-size:18px;font-weight:bold;'>"
    result += " | ".join(f"{cls}: {percentages.get(cls, 0)}%" for cls in ordered)
    result += "</div>"
    return result
# Overlay the segmentation map on the input image
def create_overlay(image, segmentation_map, alpha=0.5):
    if image.shape[:2] != segmentation_map.shape[:2]:
        segmentation_map = cv2.resize(segmentation_map, (image.shape[1], image.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)
    return cv2.addWeighted(image, 1 - alpha, segmentation_map, alpha, 0)
# Perform segmentation
def perform_segmentation(model, image_bgr):
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_tensor, orig_h, orig_w = preprocess_image(image_rgb)
    with torch.no_grad():
        prediction = model(image_tensor.to(device))
    seg_map = generate_segmentation_map(prediction, orig_h, orig_w)  # RGB
    overlay = create_overlay(image_rgb, seg_map)
    mask = prediction.argmax(1).squeeze().cpu().numpy()
    analysis = create_analysis_result(mask)
    return seg_map, overlay, analysis
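# Example standalone use (assumes a checkpoint exists at the configured path and
# that "example.jpg" is a local test image):
#   model = load_model(MODEL_PATHS["Metal Marcy"])
#   seg_map, overlay, analysis_html = perform_segmentation(model, cv2.imread("example.jpg"))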
# Single image processing
def process_coastal_image(location, input_image):
    if input_image is None:
        return None, None, "Please upload an image", "Not checked", None

    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, "Error: failed to load the model", "Not checked", None

    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
    ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])

    # Outlier detection against the reference vector or reference images
    outlier_status = "Not checked"
    is_outlier = False
    image_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    if len(ref_vector) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
        is_outlier = len(filtered) == 0
    elif len(ref_images) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr])
        is_outlier = len(filtered) == 0
    else:
        print("Warning: no reference images or reference vector available for outlier detection")
        is_outlier = False
    outlier_status = ("Outlier check: <span style='color:red;font-weight:bold'>failed</span>"
                      if is_outlier else
                      "Outlier check: <span style='color:green;font-weight:bold'>passed</span>")

    seg_map, overlay, analysis = perform_segmentation(model, image_bgr)

    # Try to upload the mask to DigitalOcean Spaces
    url = "Stored locally"
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    if is_outlier:
        analysis = ("<div style='color:red;font-weight:bold;margin-bottom:10px'>"
                    "Warning: the image failed the outlier check; results may be inaccurate!</div>"
                    + analysis)
    return seg_map, overlay, analysis, outlier_status, url
# Spatial alignment followed by segmentation
def process_with_alignment(location, reference_image, input_image):
    if reference_image is None or input_image is None:
        return None, None, None, None, "Please upload a reference image and an image to analyze", "Not processed", None

    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, None, None, "Error: failed to load the model", "Not processed", None

    ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
    tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    try:
        # Align the target image to the reference image
        aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
        aligned_tgt_bgr = aligned[1]
    except Exception as e:
        print(f"Spatial alignment failed: {e}")
        return None, None, None, None, f"Spatial alignment failed: {str(e)}", "Processing failed", None

    seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)

    # Try to upload the mask to DigitalOcean Spaces
    url = "Stored locally"
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    status = "Spatial alignment: <span style='color:green;font-weight:bold'>completed</span>"
    ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
    aligned_tgt_rgb = cv2.cvtColor(aligned_tgt_bgr, cv2.COLOR_BGR2RGB)
    return ref_rgb, aligned_tgt_rgb, seg_map, overlay, analysis, status, url
# Create the Gradio interface
def create_interface():
    scale = 0.5
    disp_w, disp_h = int(1365 * scale), int(1024 * scale)

    with gr.Blocks(title="Coastal Erosion Analysis System") as demo:
        gr.Markdown("""# Coastal Erosion Analysis System
Upload coastal photos for analysis, with segmentation and spatial alignment.""")

        with gr.Tabs():
            with gr.TabItem("Single Image Segmentation"):
                with gr.Row():
                    loc1 = gr.Radio(list(MODEL_PATHS.keys()), label="Select model", value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    inp = gr.Image(label="Input image", type="numpy", image_mode="RGB")
                    seg = gr.Image(label="Segmentation map", type="numpy", width=disp_w, height=disp_h)
                    ovl = gr.Image(label="Overlay image", type="numpy", width=disp_w, height=disp_h)
                with gr.Row():
                    btn1 = gr.Button("Run segmentation")
                    url1 = gr.Text(label="Segmentation map URL")
                status1 = gr.HTML(label="Outlier check status")
                res1 = gr.HTML(label="Analysis results")
                btn1.click(fn=process_coastal_image, inputs=[loc1, inp], outputs=[seg, ovl, res1, status1, url1])

            with gr.TabItem("Spatially Aligned Segmentation"):
                with gr.Row():
                    loc2 = gr.Radio(list(MODEL_PATHS.keys()), label="Select model", value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    ref_img = gr.Image(label="Reference image", type="numpy", image_mode="RGB")
                    tgt_img = gr.Image(label="Image to analyze", type="numpy", image_mode="RGB")
                with gr.Row():
                    btn2 = gr.Button("Run spatially aligned segmentation")
                with gr.Row():
                    orig = gr.Image(label="Original image", type="numpy", width=disp_w, height=disp_h)
                    aligned = gr.Image(label="Aligned image", type="numpy", width=disp_w, height=disp_h)
                with gr.Row():
                    seg2 = gr.Image(label="Segmentation map", type="numpy", width=disp_w, height=disp_h)
                    ovl2 = gr.Image(label="Overlay image", type="numpy", width=disp_w, height=disp_h)
                url2 = gr.Text(label="Segmentation map URL")
                status2 = gr.HTML(label="Spatial alignment status")
                res2 = gr.HTML(label="Analysis results")
                btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img], outputs=[orig, aligned, seg2, ovl2, res2, status2, url2])

    return demo
if __name__ == "__main__":
    # Create required directories
    for path in ["models", "reference_images/MM", "reference_images/SJ"]:
        os.makedirs(path, exist_ok=True)

    # Check that the model files exist
    for p in MODEL_PATHS.values():
        if not os.path.exists(p):
            print(f"Warning: model file {p} does not exist!")

    # Check that the DigitalOcean credentials are present
    do_creds = [
        os.environ.get('DO_SPACES_KEY'),
        os.environ.get('DO_SPACES_SECRET'),
        os.environ.get('DO_SPACES_REGION'),
        os.environ.get('DO_SPACES_BUCKET')
    ]
    if not all(do_creds):
        print("Warning: DigitalOcean Spaces credentials are incomplete; uploads may be unavailable")

    # Create and launch the interface
    demo = create_interface()
    # Use a launch configuration appropriate to the environment
    if HF_SPACE:
        demo.launch()
    else:
        demo.launch(share=True)
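# Local usage sketch (assumes this file is saved as app.py; the region and bucket
# values below are placeholders):
#   export DO_SPACES_KEY=...
#   export DO_SPACES_SECRET=...
#   export DO_SPACES_REGION=nyc3
#   export DO_SPACES_BUCKET=my-bucket
#   python app.py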