import gradio as gr import os import random import httpx import asyncio from dataclasses import dataclass, field from typing import Any from dotenv import load_dotenv # 加载.env文件 load_dotenv() # 常量定义 HTTP_STATUS_CENSORED = 451 HTTP_STATUS_OK = 200 MAX_SEED = 2147483647 MAX_IMAGE_SIZE = 2048 MIN_IMAGE_SIZE = 256 # 调试模式 DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true" # 模型配置映射 MODEL_CONFIGS = { "base": "results_cosine_2e-4_bs64_infallssssuum/checkpoint-e4_s82000/consolidated.00-of-01.pth", "aesthetics_fixed": "results_cosine_2e-4_bs64_infallssssuumnnnnaaa/checkpoint-e54_s43058/consolidated.00-of-01.pth", "ep3": "results_cosine_2e-4_bs64_infallssssuumnnnn/checkpoint-e3_s63858/consolidated.00-of-01.pth", "ep3latest": "autodl-fs/lumina_hf/results_cosine_2e-4_bs64_infallssssuumnnnn/checkpoint-e3_s83000/consolidated.00-of-01.pth" } def validate_dimensions(width: int, height: int) -> tuple[int, int]: """验证并调整图片尺寸""" # 确保尺寸是32的倍数 width = max(MIN_IMAGE_SIZE, min(width, MAX_IMAGE_SIZE)) height = max(MIN_IMAGE_SIZE, min(height, MAX_IMAGE_SIZE)) width = (width // 32) * 32 height = (height // 32) * 32 return width, height @dataclass class LuminaConfig: """Lumina模型配置""" model_name: str | None = None cfg: float | None = None step: int | None = None @dataclass class ImageGenerationConfig: """图像生成配置""" prompts: list[dict[str, Any]] = field(default_factory=list) width: int = 1024 height: int = 1024 seed: int | None = None use_polish: bool = False is_lumina: bool = True lumina_config: LuminaConfig = field(default_factory=LuminaConfig) class ImageClient: """图像生成客户端""" def __init__(self) -> None: # 使用环境变量中的API_TOKEN self.x_token = os.environ.get("API_TOKEN", "") if not self.x_token: raise ValueError("环境变量中未设置API_TOKEN") # API端点 self.lumina_api_url = "https://ops.api.talesofai.cn/v3/make_image" self.lumina_task_status_url = "https://ops.api.talesofai.cn/v1/artifact/task/{task_uuid}" # 轮询配置(增加超时时间以应对服务器负载) self.max_polling_attempts = 100 # 增加到100次 self.polling_interval = 3.0 # 保持3秒间隔 # 总超时时间:100 × 3.0 = 300秒 = 5分钟 # 默认请求头 self.default_headers = { "Content-Type": "application/json", "x-platform": "nieta-app/web", "X-Token": self.x_token, } def _prepare_prompt_data(self, prompt: str, negative_prompt: str = "") -> list[dict[str, Any]]: """准备提示词数据""" prompts = [ { "type": "freetext", "value": prompt, "weight": 1.0 } ] if negative_prompt: prompts.append({ "type": "freetext", "value": negative_prompt, "weight": -1.0 }) # 添加Lumina元素 prompts.append({ "type": "elementum", "value": "b5edccfe-46a2-4a14-a8ff-f4d430343805", "uuid": "b5edccfe-46a2-4a14-a8ff-f4d430343805", "weight": 1.0, "name": "lumina1", "img_url": "https://oss.talesofai.cn/picture_s/1y7f53e6itfn_0.jpeg", "domain": "", "parent": "", "label": None, "sort_index": 0, "status": "IN_USE", "polymorphi_values": {}, "sub_type": None, }) return prompts def _build_payload(self, config: ImageGenerationConfig) -> dict[str, Any]: """构建API请求载荷""" payload = { "storyId": "", "jobType": "universal", "width": config.width, "height": config.height, "rawPrompt": config.prompts, "seed": config.seed, "meta": {"entrance": "PICTURE,PURE"}, "context_model_series": None, "negative_freetext": "", "advanced_translator": config.use_polish, } if config.is_lumina: client_args = {} if config.lumina_config.model_name: client_args["ckpt_name"] = config.lumina_config.model_name if config.lumina_config.cfg is not None: client_args["cfg"] = str(config.lumina_config.cfg) if config.lumina_config.step is not None: client_args["steps"] = str(config.lumina_config.step) if client_args: payload["client_args"] = client_args return payload async def _poll_task_status(self, task_uuid: str) -> dict[str, Any]: """轮询任务状态""" async with httpx.AsyncClient(timeout=30.0) as client: for _ in range(self.max_polling_attempts): response = await client.get( self.lumina_task_status_url.format(task_uuid=task_uuid), headers=self.default_headers ) if response.status_code != HTTP_STATUS_OK: return { "success": False, "error": f"获取任务状态失败: {response.status_code} - {response.text}" } # 解析JSON响应 try: result = response.json() except Exception as e: return { "success": False, "error": f"任务状态响应解析失败: {response.text[:500]}" } # 使用正确的字段名(根据model_studio的实现) task_status = result.get("task_status") if task_status == "SUCCESS": # 从artifacts数组中获取图片URL artifacts = result.get("artifacts", []) if artifacts and len(artifacts) > 0: image_url = artifacts[0].get("url") if image_url: return { "success": True, "image_url": image_url } return { "success": False, "error": "无法从结果中提取图像URL" } elif task_status == "FAILURE": return { "success": False, "error": result.get("error", "任务执行失败") } elif task_status == "ILLEGAL_IMAGE": return { "success": False, "error": "图片不合规" } elif task_status == "TIMEOUT": return { "success": False, "error": "任务超时" } await asyncio.sleep(self.polling_interval) return { "success": False, "error": "⏳ 生图任务超时(5分钟),服务器可能正在处理大量请求,请稍后重试" } async def generate_image(self, prompt: str, negative_prompt: str, seed: int, width: int, height: int, cfg: float, steps: int, model_name: str = "base") -> tuple[str | None, str | None]: """生成图片""" try: # 获取模型路径 model_path = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["base"]) # 准备配置 config = ImageGenerationConfig( prompts=self._prepare_prompt_data(prompt, negative_prompt), width=width, height=height, seed=seed, is_lumina=True, lumina_config=LuminaConfig( model_name=model_path, cfg=cfg, step=steps ) ) # 发送生成请求 async with httpx.AsyncClient(timeout=300.0) as client: payload = self._build_payload(config) if DEBUG_MODE: print(f"DEBUG: 发送API请求到 {self.lumina_api_url}") print(f"DEBUG: 请求载荷: {payload}") response = await client.post( self.lumina_api_url, json=payload, headers=self.default_headers ) if DEBUG_MODE: print(f"DEBUG: API响应状态码: {response.status_code}") print(f"DEBUG: API响应内容: {response.text[:1000]}") if response.status_code == HTTP_STATUS_CENSORED: return None, "内容不合规" # 处理并发限制错误 if response.status_code == 433: return None, "⏳ 服务器正忙,同时生成的图片数量已达上限,请稍后重试" if response.status_code != HTTP_STATUS_OK: return None, f"API请求失败: {response.status_code} - {response.text}" # API直接返回UUID字符串(根据model_studio的实现) content = response.text.strip() task_uuid = content.replace('"', "") if DEBUG_MODE: print(f"DEBUG: API返回UUID: {task_uuid}") if not task_uuid: return None, f"未获取到任务ID,API响应: {response.text}" # 轮询任务状态 result = await self._poll_task_status(task_uuid) if result["success"]: return result["image_url"], None else: return None, result["error"] except Exception as e: return None, f"生成图片时发生错误: {str(e)}" # 创建图片生成客户端实例 image_client = ImageClient() # 示例提示词 examples = [ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "An astronaut riding a green horse", "A delicious ceviche cheesecake slice", ] css = """ .main-container { max-width: 1400px !important; margin: 0 auto !important; padding: 20px !important; } .left-panel { background: linear-gradient(145deg, #f8f9fa, #e9ecef) !important; border-radius: 16px !important; padding: 24px !important; box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important; border: 1px solid rgba(255,255,255,0.2) !important; } .right-panel { background: linear-gradient(145deg, #ffffff, #f8f9fa) !important; border-radius: 16px !important; padding: 24px !important; box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important; border: 1px solid rgba(255,255,255,0.2) !important; display: flex !important; flex-direction: column !important; align-items: center !important; justify-content: flex-start !important; min-height: 600px !important; } #main-prompt textarea { min-height: 180px !important; font-size: 15px !important; line-height: 1.6 !important; padding: 16px !important; border-radius: 12px !important; border: 2px solid #e9ecef !important; transition: all 0.3s ease !important; background: white !important; } #main-prompt textarea:focus { border-color: #4f46e5 !important; box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1) !important; } .run-button { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; border: none !important; border-radius: 12px !important; padding: 12px 32px !important; font-weight: 600 !important; font-size: 16px !important; color: white !important; box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important; transition: all 0.3s ease !important; } .run-button:hover { transform: translateY(-2px) !important; box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important; } .settings-section { background: white !important; border-radius: 12px !important; padding: 16px !important; margin-top: 12px !important; box-shadow: 0 2px 8px rgba(0,0,0,0.05) !important; border: 1px solid #e9ecef !important; } .settings-title { font-size: 18px !important; font-weight: 600 !important; color: #374151 !important; margin-bottom: 12px !important; padding-bottom: 6px !important; border-bottom: 2px solid #e9ecef !important; } .right-panel .settings-title { margin-bottom: 8px !important; padding-bottom: 4px !important; font-size: 16px !important; } .right-panel .block { min-height: unset !important; height: auto !important; flex: none !important; } .right-panel .html-container { padding: 0 !important; margin: 0 !important; } .slider-container .wrap { background: #f8f9fa !important; border-radius: 8px !important; padding: 6px !important; margin: 2px 0 !important; } .settings-section .block { margin: 4px 0 !important; } .settings-section .row { margin: 6px 0 !important; } .settings-section .form { gap: 4px !important; } .settings-section .html-container { padding: 0 !important; margin: 8px 0 4px 0 !important; } .result-image { border-radius: 16px !important; box-shadow: 0 8px 32px rgba(0,0,0,0.15) !important; max-width: 100% !important; height: auto !important; min-height: 400px !important; width: 100% !important; } .result-image img { border-radius: 16px !important; object-fit: contain !important; max-width: 100% !important; max-height: 600px !important; width: auto !important; height: auto !important; } .examples-container { margin-top: 20px !important; } .title-header { text-align: center !important; margin-bottom: 30px !important; padding: 20px !important; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; color: white !important; border-radius: 16px !important; box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3) !important; } .title-header h1 { font-size: 28px !important; font-weight: 700 !important; margin: 0 !important; text-shadow: 0 2px 4px rgba(0,0,0,0.1) !important; } """ async def infer( prompt, seed, randomize_seed, width, height, cfg, steps, model_name, progress=gr.Progress(track_tqdm=True), ): if not prompt.strip(): raise gr.Error("提示词不能为空") if randomize_seed: seed = random.randint(0, MAX_SEED) # 验证并调整尺寸 width, height = validate_dimensions(width, height) # 验证其他参数 if not 1.0 <= cfg <= 20.0: raise gr.Error("CFG Scale 必须在 1.0 到 20.0 之间") if not 1 <= steps <= 50: raise gr.Error("Steps 必须在 1 到 50 之间") image_url, error = await image_client.generate_image( prompt=prompt, negative_prompt="", seed=seed, width=width, height=height, cfg=cfg, steps=steps, model_name=model_name ) if error: raise gr.Error(error) return image_url, seed with gr.Blocks(css=css) as demo: with gr.Column(elem_classes=["main-container"]): # 标题区域 with gr.Row(elem_classes=["title-header"]): gr.HTML("