import os import cv2 import numpy as np import torch import gradio as gr import segmentation_models_pytorch as smp from PIL import Image import boto3 import uuid import io from glob import glob from pipeline.ImgOutlier import detect_outliers from pipeline.normalization import align_images # 检测是否在Hugging Face环境中运行 HF_SPACE = os.environ.get('SPACE_ID') is not None # DigitalOcean Spaces上传函数 def upload_mask(image, prefix="mask"): """ 将分割掩码图像上传到DigitalOcean Spaces Args: image: PIL Image对象 prefix: 文件名前缀 Returns: 上传文件的URL """ try: # 从环境变量获取凭据 do_key = os.environ.get('DO_SPACES_KEY') do_secret = os.environ.get('DO_SPACES_SECRET') do_region = os.environ.get('DO_SPACES_REGION') do_bucket = os.environ.get('DO_SPACES_BUCKET') # 校验凭据是否存在 if not all([do_key, do_secret, do_region, do_bucket]): return "DigitalOcean凭据未设置" # 创建S3客户端 session = boto3.session.Session() client = session.client('s3', region_name=do_region, endpoint_url=f'https://{do_region}.digitaloceanspaces.com', aws_access_key_id=do_key, aws_secret_access_key=do_secret) # 生成唯一文件名 filename = f"{prefix}_{uuid.uuid4().hex}.png" # 将图像转换为字节流 img_byte_arr = io.BytesIO() image.save(img_byte_arr, format='PNG') img_byte_arr.seek(0) # 上传到Spaces client.upload_fileobj( img_byte_arr, do_bucket, filename, ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'} ) # 返回公共URL url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}' return url except Exception as e: print(f"上传失败: {str(e)}") return f"上传错误: {str(e)}" # Global Configuration MODEL_PATHS = { "Metal Marcy": "models/MM_best_model.pth", "Silhouette Jaenette": "models/SJ_best_model.pth" } REFERENCE_VECTOR_PATHS = { "Metal Marcy": "models/MM_mean.npy", "Silhouette Jaenette": "models/SJ_mean.npy" } REFERENCE_IMAGE_DIRS = { "Metal Marcy": "reference_images/MM", "Silhouette Jaenette": "reference_images/SJ" } # Category names and color mapping CLASSES = ['background', 'cobbles', 'drysand', 'plant', 'sky', 'water', 'wetsand'] COLORS = [ [0, 0, 0], # background - black [139, 137, 137], # cobbles - dark gray [255, 228, 181], # drysand - light yellow [0, 128, 0], # plant - green [135, 206, 235], # sky - sky blue [0, 0, 255], # water - blue [194, 178, 128] # wetsand - sand brown ] # Load model function def load_model(model_path, device="cuda"): try: # 如果在HF环境中,默认使用CPU if HF_SPACE: device = "cpu" # HF Space可能没有GPU elif not torch.cuda.is_available(): device = "cpu" # 本地环境也可能没有GPU model = smp.create_model( "DeepLabV3Plus", encoder_name="efficientnet-b6", in_channels=3, classes=len(CLASSES), encoder_weights=None ) state_dict = torch.load(model_path, map_location=device) if all(k.startswith('model.') for k in state_dict.keys()): state_dict = {k[6:]: v for k, v in state_dict.items()} model.load_state_dict(state_dict) model.to(device) model.eval() print(f"模型加载成功: {model_path}") return model except Exception as e: print(f"模型加载失败: {e}") return None # Load reference vector def load_reference_vector(vector_path): try: if not os.path.exists(vector_path): print(f"参考向量文件不存在: {vector_path}") return [] ref_vector = np.load(vector_path) print(f"参考向量加载成功: {vector_path}") return ref_vector except Exception as e: print(f"参考向量加载失败 {vector_path}: {e}") return [] # Load reference image def load_reference_images(ref_dir): try: if not os.path.exists(ref_dir): print(f"参考图像目录不存在: {ref_dir}") os.makedirs(ref_dir, exist_ok=True) return [] image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp'] image_files = [] for ext in image_extensions: image_files.extend(glob(os.path.join(ref_dir, ext))) image_files.sort() reference_images = [] for file in image_files[:4]: img = cv2.imread(file) if img is not None: reference_images.append(img) print(f"从 {ref_dir} 加载了 {len(reference_images)} 张图像") return reference_images except Exception as e: print(f"加载图像失败 {ref_dir}: {e}") return [] # Preprocess the image def preprocess_image(image): if image.shape[2] == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) orig_h, orig_w = image.shape[:2] image_resized = cv2.resize(image, (1024, 1024)) image_norm = image_resized.astype(np.float32) / 255.0 mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) image_norm = (image_norm - mean) / std image_tensor = torch.from_numpy(image_norm.transpose(2, 0, 1)).float().unsqueeze(0) return image_tensor, orig_h, orig_w # Generate segmentation map and visualization def generate_segmentation_map(prediction, orig_h, orig_w): mask = prediction.argmax(1).squeeze().cpu().numpy().astype(np.uint8) mask_resized = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST) kernel = np.ones((5, 5), np.uint8) processed_mask = mask_resized.copy() for idx in range(1, len(CLASSES)): class_mask = (mask_resized == idx).astype(np.uint8) dilated_mask = cv2.dilate(class_mask, kernel, iterations=2) dilated_effect = dilated_mask & (mask_resized == 0) processed_mask[dilated_effect > 0] = idx segmentation_map = np.zeros((orig_h, orig_w, 3), dtype=np.uint8) for idx, color in enumerate(COLORS): segmentation_map[processed_mask == idx] = color return segmentation_map # Analysis result HTML def create_analysis_result(mask): total_pixels = mask.size percentages = {cls: round((np.sum(mask == i) / total_pixels) * 100, 1) for i, cls in enumerate(CLASSES)} ordered = ['sky', 'cobbles', 'plant', 'drysand', 'wetsand', 'water'] result = "
" result += " | ".join(f"{cls}: {percentages.get(cls,0)}%" for cls in ordered) result += "
" return result # Merge and overlay def create_overlay(image, segmentation_map, alpha=0.5): if image.shape[:2] != segmentation_map.shape[:2]: segmentation_map = cv2.resize(segmentation_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST) return cv2.addWeighted(image, 1-alpha, segmentation_map, alpha, 0) # Perform segmentation def perform_segmentation(model, image_bgr): device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu" image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) image_tensor, orig_h, orig_w = preprocess_image(image_rgb) with torch.no_grad(): prediction = model(image_tensor.to(device)) seg_map = generate_segmentation_map(prediction, orig_h, orig_w) # RGB overlay = create_overlay(image_rgb, seg_map) mask = prediction.argmax(1).squeeze().cpu().numpy() analysis = create_analysis_result(mask) return seg_map, overlay, analysis # Single image processing def process_coastal_image(location, input_image): if input_image is None: return None, None, "请上传一张图片", "未检测", None device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu" model = load_model(MODEL_PATHS[location], device) if model is None: return None, None, f"错误:无法加载模型", "未检测", None ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location]) ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location]) outlier_status = "未检测" is_outlier = False image_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR) if len(ref_vector) > 0: filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector) is_outlier = len(filtered) == 0 elif len(ref_images) > 0: filtered, _ = detect_outliers(ref_images, [image_bgr]) is_outlier = len(filtered) == 0 else: print("警告:没有参考图像或参考向量可用于异常检测") is_outlier = False outlier_status = "异常检测: 未通过" if is_outlier else "异常检测: 通过" seg_map, overlay, analysis = perform_segmentation(model, image_bgr) # 尝试上传到DigitalOcean Spaces url = "本地存储" try: url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_')) except Exception as e: print(f"上传失败: {e}") url = f"上传错误: {str(e)}" if is_outlier: analysis = "
警告:图像未通过异常检测,结果可能不准确!
" + analysis return seg_map, overlay, analysis, outlier_status, url # Spacial Alignment def process_with_alignment(location, reference_image, input_image): if reference_image is None or input_image is None: return None, None, None, None, "请上传参考图像和需要分析的图像", "未处理", None device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu" model = load_model(MODEL_PATHS[location], device) if model is None: return None, None, None, None, "错误:无法加载模型", "未处理", None ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR) tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR) try: aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)]) aligned_tgt_bgr = aligned[1] except Exception as e: print(f"空间对齐失败: {e}") return None, None, None, None, f"空间对齐失败: {str(e)}", "处理失败", None seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr) # 尝试上传到DigitalOcean Spaces url = "本地存储" try: url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_')) except Exception as e: print(f"上传失败: {e}") url = f"上传错误: {str(e)}" status = "空间对齐: 完成" ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB) aligned_tgt_rgb = cv2.cvtColor(aligned_tgt_bgr, cv2.COLOR_BGR2RGB) return ref_rgb, aligned_tgt_rgb, seg_map, overlay, analysis, status, url # Create the Gradio interface def create_interface(): scale = 0.5 disp_w, disp_h = int(1365*scale), int(1024*scale) with gr.Blocks(title="海岸侵蚀分析系统") as demo: gr.Markdown("""# 海岸侵蚀分析系统 上传海岸照片进行分析,包括分割和空间对齐功能。""") with gr.Tabs(): with gr.TabItem("单张图像分割"): with gr.Row(): loc1 = gr.Radio(list(MODEL_PATHS.keys()), label="选择模型", value=list(MODEL_PATHS.keys())[0]) with gr.Row(): inp = gr.Image(label="输入图像", type="numpy", image_mode="RGB") seg = gr.Image(label="分割图像", type="numpy", width=disp_w, height=disp_h) ovl = gr.Image(label="叠加图像", type="numpy", width=disp_w, height=disp_h) with gr.Row(): btn1 = gr.Button("执行分割") url1 = gr.Text(label="分割图URL") status1 = gr.HTML(label="异常检测状态") res1 = gr.HTML(label="分析结果") btn1.click(fn=process_coastal_image, inputs=[loc1, inp], outputs=[seg, ovl, res1, status1, url1]) with gr.TabItem("空间对齐分割"): with gr.Row(): loc2 = gr.Radio(list(MODEL_PATHS.keys()), label="选择模型", value=list(MODEL_PATHS.keys())[0]) with gr.Row(): ref_img = gr.Image(label="参考图像", type="numpy", image_mode="RGB") tgt_img = gr.Image(label="待分析图像", type="numpy", image_mode="RGB") with gr.Row(): btn2 = gr.Button("执行空间对齐分割") with gr.Row(): orig = gr.Image(label="原始图像", type="numpy", width=disp_w, height=disp_h) aligned = gr.Image(label="对齐后图像", type="numpy", width=disp_w, height=disp_h) with gr.Row(): seg2 = gr.Image(label="分割图像", type="numpy", width=disp_w, height=disp_h) ovl2 = gr.Image(label="叠加图像", type="numpy", width=disp_w, height=disp_h) url2 = gr.Text(label="分割图URL") status2 = gr.HTML(label="空间对齐状态") res2 = gr.HTML(label="分析结果") btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img], outputs=[orig, aligned, seg2, ovl2, res2, status2, url2]) return demo if __name__ == "__main__": # 创建必要的目录 for path in ["models", "reference_images/MM", "reference_images/SJ"]: os.makedirs(path, exist_ok=True) # 检查模型文件是否存在 for p in MODEL_PATHS.values(): if not os.path.exists(p): print(f"警告:模型文件 {p} 不存在!") # 检查DigitalOcean凭据是否存在 do_creds = [ os.environ.get('DO_SPACES_KEY'), os.environ.get('DO_SPACES_SECRET'), os.environ.get('DO_SPACES_REGION'), os.environ.get('DO_SPACES_BUCKET') ] if not all(do_creds): print("警告:DigitalOcean Spaces凭据不完整,上传功能可能不可用") # 创建并启动界面 demo = create_interface() # 在HF环境中使用适当的启动配置 if HF_SPACE: demo.launch() else: demo.launch(share=True)