|
import gradio as gr |
|
import os |
|
import random |
|
import httpx |
|
import asyncio |
|
from dataclasses import dataclass, field |
|
from typing import Any |
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a .env file into the process environment so that the
# os.environ lookups further down (DEBUG_MODE, API_TOKEN) can see them.
load_dotenv()
|
|
|
|
|
# HTTP status codes that the API client checks on responses.
HTTP_STATUS_OK = 200
HTTP_STATUS_CENSORED = 451  # reported to the user as non-compliant content

# Upper bound for the seed slider and for randomly drawn seeds (2**31 - 1).
MAX_SEED = 2**31 - 1

# Inclusive bounds, in pixels, that image width and height are clamped to.
MIN_IMAGE_SIZE = 256
MAX_IMAGE_SIZE = 2048
|
|
|
|
|
# Debug printing of API requests/responses is on only when the DEBUG_MODE
# environment variable equals "true" (compared case-insensitively).
DEBUG_MODE = os.getenv("DEBUG_MODE", "false").lower() == "true"
|
|
|
|
|
# Maps the model names offered in the UI dropdown to the checkpoint path that
# is sent to the API as ``client_args["ckpt_name"]``.
MODEL_CONFIGS: dict[str, str] = {
    "base": "results_cosine_2e-4_bs64_infallssssuum/checkpoint-e4_s82000/consolidated.00-of-01.pth",
    "aesthetics_fixed": "results_cosine_2e-4_bs64_infallssssuumnnnnaaa/checkpoint-e54_s43058/consolidated.00-of-01.pth",
    "ep3": "results_cosine_2e-4_bs64_infallssssuumnnnn/checkpoint-e3_s63858/consolidated.00-of-01.pth",
    "ep3latest": "autodl-fs/lumina_hf/results_cosine_2e-4_bs64_infallssssuumnnnn/checkpoint-e3_s83000/consolidated.00-of-01.pth",
}
|
|
|
def validate_dimensions(width: int, height: int) -> tuple[int, int]:
    """Clamp an image size to the supported range and align it to 32 pixels.

    Each dimension is limited to ``[MIN_IMAGE_SIZE, MAX_IMAGE_SIZE]`` and then
    rounded down to the nearest multiple of 32.

    Args:
        width: Requested width in pixels. Float values (such as a UI slider
            may supply) are accepted.
        height: Requested height in pixels; same handling as ``width``.

    Returns:
        The adjusted ``(width, height)`` pair, always as plain ``int`` values.
    """

    def _fit(value: float) -> int:
        clamped = max(MIN_IMAGE_SIZE, min(value, MAX_IMAGE_SIZE))
        # Floor division already yields an integral value; int() only makes
        # sure a float input does not leak out as e.g. 1024.0.
        return int(clamped // 32) * 32

    return _fit(width), _fit(height)
|
|
|
@dataclass |
|
class LuminaConfig: |
|
"""Lumina模型配置""" |
|
model_name: str | None = None |
|
cfg: float | None = None |
|
step: int | None = None |
|
|
|
@dataclass
class ImageGenerationConfig:
    """All settings needed to build one image-generation request."""

    # Prompt entries (one dict per entry); a fresh list is created per instance.
    prompts: list[dict[str, Any]] = field(default_factory=list)

    # Output size in pixels.
    width: int = 1024
    height: int = 1024

    # Generation seed; None leaves the seed unspecified.
    seed: int | None = None

    # Sent to the API as the "advanced_translator" flag.
    use_polish: bool = False

    # When True, the Lumina-specific overrides below are added to the request.
    is_lumina: bool = True

    # Lumina overrides; a fresh, empty LuminaConfig is created per instance.
    lumina_config: LuminaConfig = field(default_factory=LuminaConfig)
|
|
|
class ImageClient: |
|
"""图像生成客户端""" |
|
def __init__(self) -> None: |
|
|
|
self.x_token = os.environ.get("API_TOKEN", "") |
|
if not self.x_token: |
|
raise ValueError("环境变量中未设置API_TOKEN") |
|
|
|
|
|
self.lumina_api_url = "https://ops.api.talesofai.cn/v3/make_image" |
|
self.lumina_task_status_url = "https://ops.api.talesofai.cn/v1/artifact/task/{task_uuid}" |
|
|
|
|
|
self.max_polling_attempts = 100 |
|
self.polling_interval = 3.0 |
|
|
|
|
|
|
|
self.default_headers = { |
|
"Content-Type": "application/json", |
|
"x-platform": "nieta-app/web", |
|
"X-Token": self.x_token, |
|
} |
|
|
|
def _prepare_prompt_data(self, prompt: str, negative_prompt: str = "") -> list[dict[str, Any]]: |
|
"""准备提示词数据""" |
|
prompts = [ |
|
{ |
|
"type": "freetext", |
|
"value": prompt, |
|
"weight": 1.0 |
|
} |
|
] |
|
|
|
if negative_prompt: |
|
prompts.append({ |
|
"type": "freetext", |
|
"value": negative_prompt, |
|
"weight": -1.0 |
|
}) |
|
|
|
|
|
prompts.append({ |
|
"type": "elementum", |
|
"value": "b5edccfe-46a2-4a14-a8ff-f4d430343805", |
|
"uuid": "b5edccfe-46a2-4a14-a8ff-f4d430343805", |
|
"weight": 1.0, |
|
"name": "lumina1", |
|
"img_url": "https://oss.talesofai.cn/picture_s/1y7f53e6itfn_0.jpeg", |
|
"domain": "", |
|
"parent": "", |
|
"label": None, |
|
"sort_index": 0, |
|
"status": "IN_USE", |
|
"polymorphi_values": {}, |
|
"sub_type": None, |
|
}) |
|
|
|
return prompts |
|
|
|
def _build_payload(self, config: ImageGenerationConfig) -> dict[str, Any]: |
|
"""构建API请求载荷""" |
|
payload = { |
|
"storyId": "", |
|
"jobType": "universal", |
|
"width": config.width, |
|
"height": config.height, |
|
"rawPrompt": config.prompts, |
|
"seed": config.seed, |
|
"meta": {"entrance": "PICTURE,PURE"}, |
|
"context_model_series": None, |
|
"negative_freetext": "", |
|
"advanced_translator": config.use_polish, |
|
} |
|
|
|
if config.is_lumina: |
|
client_args = {} |
|
if config.lumina_config.model_name: |
|
client_args["ckpt_name"] = config.lumina_config.model_name |
|
if config.lumina_config.cfg is not None: |
|
client_args["cfg"] = str(config.lumina_config.cfg) |
|
if config.lumina_config.step is not None: |
|
client_args["steps"] = str(config.lumina_config.step) |
|
|
|
if client_args: |
|
payload["client_args"] = client_args |
|
|
|
return payload |
|
|
|
async def _poll_task_status(self, task_uuid: str) -> dict[str, Any]: |
|
"""轮询任务状态""" |
|
async with httpx.AsyncClient(timeout=30.0) as client: |
|
for _ in range(self.max_polling_attempts): |
|
response = await client.get( |
|
self.lumina_task_status_url.format(task_uuid=task_uuid), |
|
headers=self.default_headers |
|
) |
|
|
|
if response.status_code != HTTP_STATUS_OK: |
|
return { |
|
"success": False, |
|
"error": f"获取任务状态失败: {response.status_code} - {response.text}" |
|
} |
|
|
|
|
|
try: |
|
result = response.json() |
|
except Exception as e: |
|
return { |
|
"success": False, |
|
"error": f"任务状态响应解析失败: {response.text[:500]}" |
|
} |
|
|
|
|
|
task_status = result.get("task_status") |
|
|
|
if task_status == "SUCCESS": |
|
|
|
artifacts = result.get("artifacts", []) |
|
if artifacts and len(artifacts) > 0: |
|
image_url = artifacts[0].get("url") |
|
if image_url: |
|
return { |
|
"success": True, |
|
"image_url": image_url |
|
} |
|
return { |
|
"success": False, |
|
"error": "无法从结果中提取图像URL" |
|
} |
|
elif task_status == "FAILURE": |
|
return { |
|
"success": False, |
|
"error": result.get("error", "任务执行失败") |
|
} |
|
elif task_status == "ILLEGAL_IMAGE": |
|
return { |
|
"success": False, |
|
"error": "图片不合规" |
|
} |
|
elif task_status == "TIMEOUT": |
|
return { |
|
"success": False, |
|
"error": "任务超时" |
|
} |
|
|
|
await asyncio.sleep(self.polling_interval) |
|
|
|
return { |
|
"success": False, |
|
"error": "⏳ 生图任务超时(5分钟),服务器可能正在处理大量请求,请稍后重试" |
|
} |
|
|
|
async def generate_image(self, prompt: str, negative_prompt: str, seed: int, width: int, height: int, cfg: float, steps: int, model_name: str = "base") -> tuple[str | None, str | None]: |
|
"""生成图片""" |
|
try: |
|
|
|
model_path = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["base"]) |
|
|
|
|
|
config = ImageGenerationConfig( |
|
prompts=self._prepare_prompt_data(prompt, negative_prompt), |
|
width=width, |
|
height=height, |
|
seed=seed, |
|
is_lumina=True, |
|
lumina_config=LuminaConfig( |
|
model_name=model_path, |
|
cfg=cfg, |
|
step=steps |
|
) |
|
) |
|
|
|
|
|
async with httpx.AsyncClient(timeout=300.0) as client: |
|
payload = self._build_payload(config) |
|
if DEBUG_MODE: |
|
print(f"DEBUG: 发送API请求到 {self.lumina_api_url}") |
|
print(f"DEBUG: 请求载荷: {payload}") |
|
|
|
response = await client.post( |
|
self.lumina_api_url, |
|
json=payload, |
|
headers=self.default_headers |
|
) |
|
|
|
if DEBUG_MODE: |
|
print(f"DEBUG: API响应状态码: {response.status_code}") |
|
print(f"DEBUG: API响应内容: {response.text[:1000]}") |
|
|
|
if response.status_code == HTTP_STATUS_CENSORED: |
|
return None, "内容不合规" |
|
|
|
|
|
if response.status_code == 433: |
|
return None, "⏳ 服务器正忙,同时生成的图片数量已达上限,请稍后重试" |
|
|
|
if response.status_code != HTTP_STATUS_OK: |
|
return None, f"API请求失败: {response.status_code} - {response.text}" |
|
|
|
|
|
content = response.text.strip() |
|
task_uuid = content.replace('"', "") |
|
|
|
if DEBUG_MODE: |
|
print(f"DEBUG: API返回UUID: {task_uuid}") |
|
|
|
if not task_uuid: |
|
return None, f"未获取到任务ID,API响应: {response.text}" |
|
|
|
|
|
result = await self._poll_task_status(task_uuid) |
|
if result["success"]: |
|
return result["image_url"], None |
|
else: |
|
return None, result["error"] |
|
|
|
except Exception as e: |
|
return None, f"生成图片时发生错误: {str(e)}" |
|
|
|
|
|
# Single shared API client used by the UI callback. Construction raises
# ValueError when API_TOKEN is not set, so that failure surfaces at import.
image_client = ImageClient()
|
|
|
|
|
# Sample prompts offered in the UI's examples section.
examples: list[str] = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...). The selectors match the
# elem_classes / elem_id values assigned to components in the layout.
css = """
.main-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
    padding: 20px !important;
}

.left-panel {
    background: linear-gradient(145deg, #f8f9fa, #e9ecef) !important;
    border-radius: 16px !important;
    padding: 24px !important;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
    border: 1px solid rgba(255,255,255,0.2) !important;
}

.right-panel {
    background: linear-gradient(145deg, #ffffff, #f8f9fa) !important;
    border-radius: 16px !important;
    padding: 24px !important;
    box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
    border: 1px solid rgba(255,255,255,0.2) !important;
    display: flex !important;
    flex-direction: column !important;
    align-items: center !important;
    justify-content: flex-start !important;
    min-height: 600px !important;
}

#main-prompt textarea {
    min-height: 180px !important;
    font-size: 15px !important;
    line-height: 1.6 !important;
    padding: 16px !important;
    border-radius: 12px !important;
    border: 2px solid #e9ecef !important;
    transition: all 0.3s ease !important;
    background: white !important;
}

#main-prompt textarea:focus {
    border-color: #4f46e5 !important;
    box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1) !important;
}

.run-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 12px 32px !important;
    font-weight: 600 !important;
    font-size: 16px !important;
    color: white !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
    transition: all 0.3s ease !important;
}

.run-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}

.settings-section {
    background: white !important;
    border-radius: 12px !important;
    padding: 16px !important;
    margin-top: 12px !important;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05) !important;
    border: 1px solid #e9ecef !important;
}

.settings-title {
    font-size: 18px !important;
    font-weight: 600 !important;
    color: #374151 !important;
    margin-bottom: 12px !important;
    padding-bottom: 6px !important;
    border-bottom: 2px solid #e9ecef !important;
}

.right-panel .settings-title {
    margin-bottom: 8px !important;
    padding-bottom: 4px !important;
    font-size: 16px !important;
}

.right-panel .block {
    min-height: unset !important;
    height: auto !important;
    flex: none !important;
}

.right-panel .html-container {
    padding: 0 !important;
    margin: 0 !important;
}

.slider-container .wrap {
    background: #f8f9fa !important;
    border-radius: 8px !important;
    padding: 6px !important;
    margin: 2px 0 !important;
}

.settings-section .block {
    margin: 4px 0 !important;
}

.settings-section .row {
    margin: 6px 0 !important;
}

.settings-section .form {
    gap: 4px !important;
}

.settings-section .html-container {
    padding: 0 !important;
    margin: 8px 0 4px 0 !important;
}

.result-image {
    border-radius: 16px !important;
    box-shadow: 0 8px 32px rgba(0,0,0,0.15) !important;
    max-width: 100% !important;
    height: auto !important;
    min-height: 400px !important;
    width: 100% !important;
}

.result-image img {
    border-radius: 16px !important;
    object-fit: contain !important;
    max-width: 100% !important;
    max-height: 600px !important;
    width: auto !important;
    height: auto !important;
}

.examples-container {
    margin-top: 20px !important;
}

.title-header {
    text-align: center !important;
    margin-bottom: 30px !important;
    padding: 20px !important;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    border-radius: 16px !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3) !important;
}

.title-header h1 {
    font-size: 28px !important;
    font-weight: 700 !important;
    margin: 0 !important;
    text-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
}
"""
|
|
|
async def infer(
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    cfg,
    steps,
    model_name,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio callback: validate the inputs and generate one image.

    Args:
        prompt: Prompt text; must contain non-whitespace characters.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, a random seed in ``[0, MAX_SEED]`` replaces
            ``seed``.
        width: Requested width; clamped and aligned by ``validate_dimensions``.
        height: Requested height; clamped and aligned the same way.
        cfg: Guidance scale, must be within 1.0-20.0.
        steps: Sampling steps, must be within 1-50.
        model_name: Key of the model to use (see ``MODEL_CONFIGS``).
        progress: Gradio progress tracker (not used directly in this body).

    Returns:
        ``(image_url, seed)`` - the generated image URL and the seed that was
        actually used.

    Raises:
        gr.Error: If the prompt is empty, ``cfg``/``steps`` are out of range,
            or the API client reports an error.
    """
    # A missing prompt is handled like an empty one rather than failing on
    # ``.strip()``.
    if not prompt or not prompt.strip():
        raise gr.Error("提示词不能为空")

    # Slider values may arrive as floats; the request should carry integers
    # for seed, size and steps (steps is stringified, so 30.0 would be sent
    # as "30.0").
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    elif seed is not None:
        seed = int(seed)

    width, height = validate_dimensions(int(width), int(height))

    if not 1.0 <= cfg <= 20.0:
        raise gr.Error("CFG Scale 必须在 1.0 到 20.0 之间")
    if not 1 <= steps <= 50:
        raise gr.Error("Steps 必须在 1 到 50 之间")
    steps = int(steps)

    image_url, error = await image_client.generate_image(
        prompt=prompt,
        negative_prompt="",
        seed=seed,
        width=width,
        height=height,
        cfg=cfg,
        steps=steps,
        model_name=model_name,
    )

    if error:
        raise gr.Error(error)

    return image_url, seed
|
|
|
# UI definition: a title bar, a left column with the prompt and parameters,
# and a right column showing the generated image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes=["main-container"]):
        # Title bar.
        with gr.Row(elem_classes=["title-header"]):
            gr.HTML("<h1>🎨 Lumina Text-to-Image Playground</h1>")

        with gr.Row(equal_height=True):
            # Left column: prompt, run button, parameters, examples.
            with gr.Column(scale=1, elem_classes=["left-panel"]):
                gr.HTML("<div class='settings-title'>✨ Generation Settings</div>")

                prompt = gr.Text(
                    label="Prompt",
                    show_label=True,
                    max_lines=10,
                    lines=8,
                    placeholder="Describe what you want to generate in detail...",
                    elem_id="main-prompt",
                )

                run_button = gr.Button("🚀 Generate Image", elem_classes=["run-button"], variant="primary")

                with gr.Column(elem_classes=["settings-section"]):
                    gr.HTML("<div class='settings-title'>🎛️ Parameters</div>")

                    # Seed controls.
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                            elem_classes=["slider-container"],
                        )
                        randomize_seed = gr.Checkbox(label="Random Seed", value=True)

                    # Output size.
                    gr.HTML("<div style='margin: 16px 0 8px 0; font-weight: 600; color: #6B7280;'>📐 Image Dimensions</div>")
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            elem_classes=["slider-container"],
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            elem_classes=["slider-container"],
                        )

                    # Sampling parameters.
                    gr.HTML("<div style='margin: 16px 0 8px 0; font-weight: 600; color: #6B7280;'>⚙️ Generation Parameters</div>")
                    with gr.Row():
                        cfg = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=20.0,
                            step=0.1,
                            value=5.5,
                            elem_classes=["slider-container"],
                        )
                        steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=30,
                            elem_classes=["slider-container"],
                        )

                    # Model choice; the options are the keys of MODEL_CONFIGS.
                    model_name = gr.Dropdown(
                        label="🤖 Model Selection",
                        choices=list(MODEL_CONFIGS.keys()),
                        value="base",
                    )

                with gr.Column(elem_classes=["examples-container"]):
                    gr.HTML("<div class='settings-title'>💡 Example Prompts</div>")
                    gr.Examples(examples=examples, inputs=[prompt])

            # Right column: the generated image.
            with gr.Column(scale=1, elem_classes=["right-panel"]):
                gr.HTML("<div class='settings-title'>🖼️ Generated Image</div>")
                result = gr.Image(
                    label="Result",
                    show_label=False,
                    elem_classes=["result-image"],
                    height=600,
                    container=True,
                    show_download_button=True,
                    show_share_button=False,
                )

    # Run `infer` on button click or prompt submit; the seed output writes the
    # seed that was actually used back into the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            cfg,
            steps,
            model_name,
        ],
        outputs=[result, seed],
    )
|
|
|
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|