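# NOTE: the following lines are residue from the web file viewer this source
# was copied from; they are not part of the program.
# Aowrow's picture
# init commit
# d4a836b
# raw
# history blame
# 20.5 kB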
import gradio as gr
import os
import random
import httpx
import asyncio
from dataclasses import dataclass, field
from typing import Any
from dotenv import load_dotenv
# Load the .env file (must run before the os.environ lookups further down,
# e.g. DEBUG_MODE and API_TOKEN, so that values from the file are visible).
load_dotenv()
# Constant definitions.
# HTTP status codes the client branches on.
HTTP_STATUS_CENSORED: int = 451
HTTP_STATUS_OK: int = 200

# Upper bound for the seed (largest signed 32-bit integer) and the allowed
# range for each image side, in pixels.
MAX_SEED: int = 2**31 - 1
MAX_IMAGE_SIZE: int = 2048
MIN_IMAGE_SIZE: int = 256

# Debug mode: on only when the DEBUG_MODE environment variable is "true"
# (compared case-insensitively); anything else, or unset, means off.
DEBUG_MODE: bool = os.getenv("DEBUG_MODE", "false").lower() == "true"

# Mapping from the model name shown in the UI to its checkpoint path.
MODEL_CONFIGS: dict[str, str] = {
    "base": "results_cosine_2e-4_bs64_infallssssuum/checkpoint-e4_s82000/consolidated.00-of-01.pth",
    "aesthetics_fixed": "results_cosine_2e-4_bs64_infallssssuumnnnnaaa/checkpoint-e54_s43058/consolidated.00-of-01.pth",
    "ep3": "results_cosine_2e-4_bs64_infallssssuumnnnn/checkpoint-e3_s63858/consolidated.00-of-01.pth",
    "ep3latest": "autodl-fs/lumina_hf/results_cosine_2e-4_bs64_infallssssuumnnnn/checkpoint-e3_s83000/consolidated.00-of-01.pth",
}
def validate_dimensions(width: int, height: int) -> tuple[int, int]:
"""验证并调整图片尺寸"""
# 确保尺寸是32的倍数
width = max(MIN_IMAGE_SIZE, min(width, MAX_IMAGE_SIZE))
height = max(MIN_IMAGE_SIZE, min(height, MAX_IMAGE_SIZE))
width = (width // 32) * 32
height = (height // 32) * 32
return width, height
@dataclass
class LuminaConfig:
    """Lumina model settings.

    All fields are optional; ImageClient._build_payload only forwards the ones
    that are set, inside the request's ``client_args`` object.
    """
    # Checkpoint identifier; forwarded as client_args["ckpt_name"] when truthy.
    model_name: str | None = None
    # CFG value; forwarded as the string client_args["cfg"] when not None.
    cfg: float | None = None
    # Step count; forwarded as the string client_args["steps"] when not None.
    step: int | None = None
@dataclass
class ImageGenerationConfig:
    """Image generation settings consumed by ImageClient._build_payload."""
    # Prompt entries; sent as the payload's "rawPrompt" list.
    prompts: list[dict[str, Any]] = field(default_factory=list)
    # Requested image size, sent as "width" / "height".
    width: int = 1024
    height: int = 1024
    # Seed sent as "seed"; None is passed through unchanged.
    seed: int | None = None
    # Sent as the payload's "advanced_translator" flag.
    use_polish: bool = False
    # When True, the values in lumina_config are added as "client_args".
    is_lumina: bool = True
    # Lumina-specific settings (see LuminaConfig).
    lumina_config: LuminaConfig = field(default_factory=LuminaConfig)
class ImageClient:
"""图像生成客户端"""
def __init__(self) -> None:
# 使用环境变量中的API_TOKEN
self.x_token = os.environ.get("API_TOKEN", "")
if not self.x_token:
raise ValueError("环境变量中未设置API_TOKEN")
# API端点
self.lumina_api_url = "https://ops.api.talesofai.cn/v3/make_image"
self.lumina_task_status_url = "https://ops.api.talesofai.cn/v1/artifact/task/{task_uuid}"
# 轮询配置(增加超时时间以应对服务器负载)
self.max_polling_attempts = 100 # 增加到100次
self.polling_interval = 3.0 # 保持3秒间隔
# 总超时时间:100 × 3.0 = 300秒 = 5分钟
# 默认请求头
self.default_headers = {
"Content-Type": "application/json",
"x-platform": "nieta-app/web",
"X-Token": self.x_token,
}
def _prepare_prompt_data(self, prompt: str, negative_prompt: str = "") -> list[dict[str, Any]]:
"""准备提示词数据"""
prompts = [
{
"type": "freetext",
"value": prompt,
"weight": 1.0
}
]
if negative_prompt:
prompts.append({
"type": "freetext",
"value": negative_prompt,
"weight": -1.0
})
# 添加Lumina元素
prompts.append({
"type": "elementum",
"value": "b5edccfe-46a2-4a14-a8ff-f4d430343805",
"uuid": "b5edccfe-46a2-4a14-a8ff-f4d430343805",
"weight": 1.0,
"name": "lumina1",
"img_url": "https://oss.talesofai.cn/picture_s/1y7f53e6itfn_0.jpeg",
"domain": "",
"parent": "",
"label": None,
"sort_index": 0,
"status": "IN_USE",
"polymorphi_values": {},
"sub_type": None,
})
return prompts
def _build_payload(self, config: ImageGenerationConfig) -> dict[str, Any]:
"""构建API请求载荷"""
payload = {
"storyId": "",
"jobType": "universal",
"width": config.width,
"height": config.height,
"rawPrompt": config.prompts,
"seed": config.seed,
"meta": {"entrance": "PICTURE,PURE"},
"context_model_series": None,
"negative_freetext": "",
"advanced_translator": config.use_polish,
}
if config.is_lumina:
client_args = {}
if config.lumina_config.model_name:
client_args["ckpt_name"] = config.lumina_config.model_name
if config.lumina_config.cfg is not None:
client_args["cfg"] = str(config.lumina_config.cfg)
if config.lumina_config.step is not None:
client_args["steps"] = str(config.lumina_config.step)
if client_args:
payload["client_args"] = client_args
return payload
async def _poll_task_status(self, task_uuid: str) -> dict[str, Any]:
"""轮询任务状态"""
async with httpx.AsyncClient(timeout=30.0) as client:
for _ in range(self.max_polling_attempts):
response = await client.get(
self.lumina_task_status_url.format(task_uuid=task_uuid),
headers=self.default_headers
)
if response.status_code != HTTP_STATUS_OK:
return {
"success": False,
"error": f"获取任务状态失败: {response.status_code} - {response.text}"
}
# 解析JSON响应
try:
result = response.json()
except Exception as e:
return {
"success": False,
"error": f"任务状态响应解析失败: {response.text[:500]}"
}
# 使用正确的字段名(根据model_studio的实现)
task_status = result.get("task_status")
if task_status == "SUCCESS":
# 从artifacts数组中获取图片URL
artifacts = result.get("artifacts", [])
if artifacts and len(artifacts) > 0:
image_url = artifacts[0].get("url")
if image_url:
return {
"success": True,
"image_url": image_url
}
return {
"success": False,
"error": "无法从结果中提取图像URL"
}
elif task_status == "FAILURE":
return {
"success": False,
"error": result.get("error", "任务执行失败")
}
elif task_status == "ILLEGAL_IMAGE":
return {
"success": False,
"error": "图片不合规"
}
elif task_status == "TIMEOUT":
return {
"success": False,
"error": "任务超时"
}
await asyncio.sleep(self.polling_interval)
return {
"success": False,
"error": "⏳ 生图任务超时(5分钟),服务器可能正在处理大量请求,请稍后重试"
}
async def generate_image(self, prompt: str, negative_prompt: str, seed: int, width: int, height: int, cfg: float, steps: int, model_name: str = "base") -> tuple[str | None, str | None]:
"""生成图片"""
try:
# 获取模型路径
model_path = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["base"])
# 准备配置
config = ImageGenerationConfig(
prompts=self._prepare_prompt_data(prompt, negative_prompt),
width=width,
height=height,
seed=seed,
is_lumina=True,
lumina_config=LuminaConfig(
model_name=model_path,
cfg=cfg,
step=steps
)
)
# 发送生成请求
async with httpx.AsyncClient(timeout=300.0) as client:
payload = self._build_payload(config)
if DEBUG_MODE:
print(f"DEBUG: 发送API请求到 {self.lumina_api_url}")
print(f"DEBUG: 请求载荷: {payload}")
response = await client.post(
self.lumina_api_url,
json=payload,
headers=self.default_headers
)
if DEBUG_MODE:
print(f"DEBUG: API响应状态码: {response.status_code}")
print(f"DEBUG: API响应内容: {response.text[:1000]}")
if response.status_code == HTTP_STATUS_CENSORED:
return None, "内容不合规"
# 处理并发限制错误
if response.status_code == 433:
return None, "⏳ 服务器正忙,同时生成的图片数量已达上限,请稍后重试"
if response.status_code != HTTP_STATUS_OK:
return None, f"API请求失败: {response.status_code} - {response.text}"
# API直接返回UUID字符串(根据model_studio的实现)
content = response.text.strip()
task_uuid = content.replace('"', "")
if DEBUG_MODE:
print(f"DEBUG: API返回UUID: {task_uuid}")
if not task_uuid:
return None, f"未获取到任务ID,API响应: {response.text}"
# 轮询任务状态
result = await self._poll_task_status(task_uuid)
if result["success"]:
return result["image_url"], None
else:
return None, result["error"]
except Exception as e:
return None, f"生成图片时发生错误: {str(e)}"
# Create the image generation client instance.
# NOTE: ImageClient() raises ValueError when API_TOKEN is not set, so a missing
# token makes importing this module fail.
image_client = ImageClient()
# Example prompts.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
# Custom stylesheet passed to gr.Blocks(css=...). The class names and the
# #main-prompt id used here match the elem_classes / elem_id values assigned
# to the components in the layout below.
css = """
.main-container {
max-width: 1400px !important;
margin: 0 auto !important;
padding: 20px !important;
}
.left-panel {
background: linear-gradient(145deg, #f8f9fa, #e9ecef) !important;
border-radius: 16px !important;
padding: 24px !important;
box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
border: 1px solid rgba(255,255,255,0.2) !important;
}
.right-panel {
background: linear-gradient(145deg, #ffffff, #f8f9fa) !important;
border-radius: 16px !important;
padding: 24px !important;
box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
border: 1px solid rgba(255,255,255,0.2) !important;
display: flex !important;
flex-direction: column !important;
align-items: center !important;
justify-content: flex-start !important;
min-height: 600px !important;
}
#main-prompt textarea {
min-height: 180px !important;
font-size: 15px !important;
line-height: 1.6 !important;
padding: 16px !important;
border-radius: 12px !important;
border: 2px solid #e9ecef !important;
transition: all 0.3s ease !important;
background: white !important;
}
#main-prompt textarea:focus {
border-color: #4f46e5 !important;
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1) !important;
}
.run-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
border-radius: 12px !important;
padding: 12px 32px !important;
font-weight: 600 !important;
font-size: 16px !important;
color: white !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
transition: all 0.3s ease !important;
}
.run-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
.settings-section {
background: white !important;
border-radius: 12px !important;
padding: 16px !important;
margin-top: 12px !important;
box-shadow: 0 2px 8px rgba(0,0,0,0.05) !important;
border: 1px solid #e9ecef !important;
}
.settings-title {
font-size: 18px !important;
font-weight: 600 !important;
color: #374151 !important;
margin-bottom: 12px !important;
padding-bottom: 6px !important;
border-bottom: 2px solid #e9ecef !important;
}
.right-panel .settings-title {
margin-bottom: 8px !important;
padding-bottom: 4px !important;
font-size: 16px !important;
}
.right-panel .block {
min-height: unset !important;
height: auto !important;
flex: none !important;
}
.right-panel .html-container {
padding: 0 !important;
margin: 0 !important;
}
.slider-container .wrap {
background: #f8f9fa !important;
border-radius: 8px !important;
padding: 6px !important;
margin: 2px 0 !important;
}
.settings-section .block {
margin: 4px 0 !important;
}
.settings-section .row {
margin: 6px 0 !important;
}
.settings-section .form {
gap: 4px !important;
}
.settings-section .html-container {
padding: 0 !important;
margin: 8px 0 4px 0 !important;
}
.result-image {
border-radius: 16px !important;
box-shadow: 0 8px 32px rgba(0,0,0,0.15) !important;
max-width: 100% !important;
height: auto !important;
min-height: 400px !important;
width: 100% !important;
}
.result-image img {
border-radius: 16px !important;
object-fit: contain !important;
max-width: 100% !important;
max-height: 600px !important;
width: auto !important;
height: auto !important;
}
.examples-container {
margin-top: 20px !important;
}
.title-header {
text-align: center !important;
margin-bottom: 30px !important;
padding: 20px !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
color: white !important;
border-radius: 16px !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3) !important;
}
.title-header h1 {
font-size: 28px !important;
font-weight: 700 !important;
margin: 0 !important;
text-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
}
"""
async def infer(
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    cfg,
    steps,
    model_name,
    progress=gr.Progress(track_tqdm=True),
):
    """Validate the UI inputs, request one image and return it with its seed.

    Args:
        prompt: Prompt text from the textbox.
        seed: Seed from the slider; ignored when ``randomize_seed`` is set.
        randomize_seed: When true, a random seed in ``[0, MAX_SEED]`` is used.
        width: Requested width; adjusted by ``validate_dimensions``.
        height: Requested height; adjusted by ``validate_dimensions``.
        cfg: CFG scale, must lie in ``[1.0, 20.0]``.
        steps: Step count, must lie in ``[1, 50]``.
        model_name: Model key forwarded to the client.
        progress: Not referenced in the body; presumably declared so Gradio
            shows its progress tracker for this handler.

    Returns:
        ``(image_url, seed)`` where ``seed`` is the seed actually used.

    Raises:
        gr.Error: On an empty prompt, an out-of-range parameter, an error
            reported by the client, or a response carrying no image.
    """
    # The textbox value may be missing entirely, not just blank.
    if not prompt or not prompt.strip():
        raise gr.Error("提示词不能为空")

    # Validate the remaining parameters.
    if not 1.0 <= cfg <= 20.0:
        raise gr.Error("CFG Scale 必须在 1.0 到 20.0 之间")
    if not 1 <= steps <= 50:
        raise gr.Error("Steps 必须在 1 到 50 之间")

    # Slider values can arrive as floats; the API payload expects whole
    # numbers (steps is serialised with str(), so 30.0 would become "30.0").
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    steps = int(steps)

    # Validate and adjust the image size.
    width, height = validate_dimensions(width, height)

    image_url, error = await image_client.generate_image(
        prompt=prompt,
        negative_prompt="",
        seed=seed,
        width=width,
        height=height,
        cfg=cfg,
        steps=steps,
        model_name=model_name,
    )
    if error:
        raise gr.Error(error)
    # Never present an empty result as a success.
    if not image_url:
        raise gr.Error("生成图片失败:服务未返回图片或错误信息")
    return image_url, seed
# Build the Gradio UI: a title bar, a left column with the prompt and the
# generation parameters, and a right column showing the generated image.
# NOTE(review): the source lost its indentation; the nesting below is
# reconstructed from the layout comments — confirm against the running app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes=["main-container"]):
        # Title area
        with gr.Row(elem_classes=["title-header"]):
            gr.HTML("<h1>🎨 Lumina Text-to-Image Playground</h1>")
        # Main two-column (left/right) layout
        with gr.Row(equal_height=True):
            # Left panel: parameter settings
            with gr.Column(scale=1, elem_classes=["left-panel"]):
                gr.HTML("<div class='settings-title'>✨ Generation Settings</div>")
                # Main prompt input
                prompt = gr.Text(
                    label="Prompt",
                    show_label=True,
                    max_lines=10,
                    lines=8,
                    placeholder="Describe what you want to generate in detail...",
                    elem_id="main-prompt",
                )
                run_button = gr.Button("🚀 Generate Image", elem_classes=["run-button"], variant="primary")
                # Parameter settings area
                with gr.Column(elem_classes=["settings-section"]):
                    gr.HTML("<div class='settings-title'>🎛️ Parameters</div>")
                    # Seed settings
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                            elem_classes=["slider-container"]
                        )
                        randomize_seed = gr.Checkbox(label="Random Seed", value=True)
                    # Image size settings
                    gr.HTML("<div style='margin: 16px 0 8px 0; font-weight: 600; color: #6B7280;'>📐 Image Dimensions</div>")
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            elem_classes=["slider-container"]
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            elem_classes=["slider-container"]
                        )
                    # Generation parameters (ranges match the checks in infer)
                    gr.HTML("<div style='margin: 16px 0 8px 0; font-weight: 600; color: #6B7280;'>⚙️ Generation Parameters</div>")
                    with gr.Row():
                        cfg = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=20.0,
                            step=0.1,
                            value=5.5,
                            elem_classes=["slider-container"]
                        )
                        steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=30,
                            elem_classes=["slider-container"]
                        )
                    # Model selection (choices are the MODEL_CONFIGS keys)
                    model_name = gr.Dropdown(
                        label="🤖 Model Selection",
                        choices=list(MODEL_CONFIGS.keys()),
                        value="base"
                    )
                # Example prompts
                with gr.Column(elem_classes=["examples-container"]):
                    gr.HTML("<div class='settings-title'>💡 Example Prompts</div>")
                    gr.Examples(examples=examples, inputs=[prompt])
            # Right panel: image display
            with gr.Column(scale=1, elem_classes=["right-panel"]):
                gr.HTML("<div class='settings-title'>🖼️ Generated Image</div>")
                result = gr.Image(
                    label="Result",
                    show_label=False,
                    elem_classes=["result-image"],
                    height=600,
                    container=True,
                    show_download_button=True,
                    show_share_button=False
                )
    # Run infer on a button click or on submitting the prompt box; the seed
    # that was actually used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            cfg,
            steps,
            model_name,
        ],
        outputs=[result, seed],
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()