Xsong123 committed
Commit 3cdae16 · verified · 1 Parent(s): 41d5780

Create app.py

Files changed (1): app.py +373 -0
app.py ADDED
@@ -0,0 +1,373 @@
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ import random
+ import json
+ import os
+ from PIL import Image
+ from diffusers import FluxKontextPipeline
+ from diffusers.utils import load_image
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
+ from safetensors.torch import load_file
+ import requests
+ import re
+
+ # Load Kontext model
+ MAX_SEED = np.iinfo(np.int32).max
+
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
+
+ # Load LoRA data (you'll need to create this JSON file or modify to load your LoRAs)
+ with open("flux_loras.json", "r") as file:
+     data = json.load(file)
+     flux_loras_raw = [
+         {
+             "image": item["image"],
+             "title": item["title"],
+             "repo": item["repo"],
+             "trigger_word": item.get("trigger_word", ""),
+             "trigger_position": item.get("trigger_position", "prepend"),
+             "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
+             "lora_type": item.get("lora_type", "flux"),
+             "lora_scale_config": item.get("lora_scale", 1.5),
+             "prompt_placeholder": item.get("prompt_placeholder", ""),
+         }
+         for item in data
+     ]
+ print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")
+ # Global variables for LoRA management
+ lora_cache = {}
+
+ def load_lora_weights(repo_id, weights_filename):
+     """Load LoRA weights from HuggingFace"""
+     try:
+         if repo_id not in lora_cache:
+             lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
+             lora_cache[repo_id] = lora_path
+         return lora_cache[repo_id]
+     except Exception as e:
+         print(f"Error loading LoRA from {repo_id}: {e}")
+         return None
+
+ def update_selection(selected_state: gr.SelectData, flux_loras):
+     """Update UI when a LoRA is selected"""
+     if selected_state.index >= len(flux_loras):
+         return "### No LoRA selected", gr.update(), None, gr.update()
+
+     lora_repo = flux_loras[selected_state.index]["repo"]
+     trigger_word = flux_loras[selected_state.index]["trigger_word"]
+
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
+     config_placeholder = flux_loras[selected_state.index]["prompt_placeholder"]
+     if config_placeholder:
+         new_placeholder = config_placeholder
+     else:
+         new_placeholder = "opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
+
+     print("Selected LoRA: ", flux_loras[selected_state.index])
+     optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.5)
+     print("Optimal scale: ", optimal_scale)
+     return updated_text, gr.update(placeholder=new_placeholder), selected_state.index, optimal_scale
+
+ def get_huggingface_lora(link):
+     """Download LoRA from HuggingFace link"""
+     split_link = link.split("/")
+     if len(split_link) == 2:
+         try:
+             model_card = ModelCard.load(link)
+             trigger_word = model_card.data.get("instance_prompt", "")
+
+             fs = HfFileSystem()
+             list_of_files = fs.ls(link, detail=False)
+             safetensors_file = None
+
+             for file in list_of_files:
+                 if file.endswith(".safetensors") and "lora" in file.lower():
+                     safetensors_file = file.split("/")[-1]
+                     break
+
+             if not safetensors_file:
+                 safetensors_file = "pytorch_lora_weights.safetensors"
+
+             return split_link[1], safetensors_file, trigger_word
+         except Exception as e:
+             raise Exception(f"Error loading LoRA: {e}")
+     else:
+         raise Exception("Invalid HuggingFace repository format")
+
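+ # Illustrative usage with a hypothetical repo id: get_huggingface_lora("username/my-style-lora")
+ # would return ("my-style-lora", <first "*lora*.safetensors" file found, else the default name>,
+ # <"instance_prompt" from the model card, else "">). Any input that is not exactly
+ # "user/repo" raises the invalid-format error above.
+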
+ def load_custom_lora(link):
+     """Load custom LoRA from user input"""
+     if not link:
+         return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it", None
+
+     try:
+         repo_name, weights_file, trigger_word = get_huggingface_lora(link)
+
+         card = f'''
+         <div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
+             <span><strong>Loaded custom LoRA:</strong></span>
+             <div style="margin-top: 8px;">
+                 <h4>{repo_name}</h4>
+                 <small>{"Using: <code><b>"+trigger_word+"</b></code> as trigger word" if trigger_word else "No trigger word found"}</small>
+             </div>
+         </div>
+         '''
+
+         custom_lora_data = {
+             "repo": link,
+             "weights": weights_file,
+             "trigger_word": trigger_word
+         }
+
+         return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None
+
+     except Exception as e:
+         return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it", None
+
+ def remove_custom_lora():
+     """Remove custom LoRA"""
+     return "", gr.update(visible=False), gr.update(visible=False), None, None
+
+ def classify_gallery(flux_loras):
+     """Sort gallery by likes"""
+     sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
+     return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
+
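+ # Note: the JSON loader above does not copy a "likes" field into each entry, so
+ # x.get("likes", 0) is 0 for every item and this sort keeps the original JSON order.
+ # The (image, title) tuples returned here are the (media, caption) pairs gr.Gallery expects.
+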
+ def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+     """Wrapper function to handle state serialization"""
+     return infer_with_lora(input_image, prompt, selected_index, lora_state, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
+
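+ # @spaces.GPU is the Hugging Face Spaces ZeroGPU hook: it requests a GPU for the
+ # duration of each call to the decorated function and releases it afterwards.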
+ @spaces.GPU
+ def infer_with_lora(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+     """Generate image with selected LoRA"""
+     global pipe
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # Determine which LoRA to use
+     lora_to_use = None
+     if custom_lora:
+         lora_to_use = custom_lora
+     elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
+         lora_to_use = flux_loras[selected_index]
+         print(f"Using gallery LoRA at index {selected_index}")
+
+     # Load LoRA if needed
+     print(f"LoRA to use: {lora_to_use}")
+     if lora_to_use:
+         try:
+             if "selected_lora" in pipe.get_active_adapters():
+                 pipe.unload_lora_weights()
+
+             lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
+             if lora_path:
+                 pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
+                 pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
+                 print(f"Loaded: {lora_path} with scale {lora_scale}")
+
+         except Exception as e:
+             print(f"Error loading LoRA: {e}")
+
+     input_image = input_image.convert("RGB")
+
+     # Add trigger word to prompt; guard against no LoRA being selected and against
+     # custom LoRAs, which may lack a "lora_type" key
+     trigger_word = lora_to_use["trigger_word"] if lora_to_use else ""
+     is_kontext_lora = bool(lora_to_use) and lora_to_use.get("lora_type", "flux") == "kontext"
+     if not is_kontext_lora:
+         if portrait_mode:
+             if trigger_word == ", How2Draw":
+                 prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
+             elif trigger_word == ", video game screenshot in the style of THSMS":
+                 prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
+             else:
+                 prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
+         else:
+             if trigger_word == ", How2Draw":
+                 prompt = f"create a How2Draw sketch of the photo {prompt}"
+             elif trigger_word == ", video game screenshot in the style of THSMS":
+                 prompt = f"create a video game screenshot in the style of THSMS of the photo, {prompt}."
+             else:
+                 prompt = f"convert the style of this photo {prompt} to {trigger_word}."
+     else:
+         prompt = f"{trigger_word}. {prompt}."
+
+     try:
+         image = pipe(
+             image=input_image,
+             width=input_image.size[0],
+             height=input_image.size[1],
+             prompt=prompt,
+             guidance_scale=guidance_scale,
+             generator=torch.Generator().manual_seed(seed)
+         ).images[0]
+
+         return image, seed, gr.update(visible=True), lora_scale
+
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         return None, seed, gr.update(visible=False), lora_scale
+
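+ # Note: the generator above is a CPU torch.Generator, which diffusers pipelines
+ # accept even when the model runs on CUDA; a fixed seed therefore reproduces a
+ # generation for identical inputs, while "randomize_seed" draws a fresh seed per call.
+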
+ # CSS styling
+ css = """
+ #main_app {
+     display: flex;
+     gap: 20px;
+ }
+ #box_column {
+     min-width: 400px;
+ }
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #selected_lora {
+     color: #2563eb;
+     font-weight: bold;
+ }
+ #prompt {
+     flex-grow: 1;
+ }
+ #run_button {
+     background: linear-gradient(45deg, #2563eb, #3b82f6);
+     color: white;
+     border: none;
+     padding: 8px 16px;
+     border-radius: 6px;
+     font-weight: bold;
+ }
+ .custom_lora_card {
+     background: #f8fafc;
+     border: 1px solid #e2e8f0;
+     border-radius: 8px;
+     padding: 12px;
+     margin: 8px 0;
+ }
+ #gallery{
+     overflow: scroll !important
+ }
+ """
+
+ # Create Gradio interface
+ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
+     gr_flux_loras = gr.State(value=flux_loras_raw)
+
+     title = gr.HTML(
+         """<h1><img src="https://huggingface.co/spaces/kontext-community/FLUX.1-Kontext-portrait/resolve/main/dora_kontext.png" alt="LoRA"> FLUX.1 Kontext LoRA the Explorer</h1>""",
+         elem_id="title",
+     )
+     gr.Markdown("Flux.1 Kontext [dev] + community Kontext & Flux LoRAs 🤗")
+
+     selected_state = gr.State(value=None)
+     custom_loaded_lora = gr.State(value=None)
+     lora_state = gr.State(value=1.0)
+
+     with gr.Row(elem_id="main_app"):
+         with gr.Column(scale=4, elem_id="box_column"):
+             with gr.Group(elem_id="gallery_box"):
+                 input_image = gr.Image(label="Upload a picture of yourself", type="pil", height=300)
+                 portrait_mode = gr.Checkbox(label="portrait mode", value=True)
+                 gallery = gr.Gallery(
+                     label="Pick a LoRA",
+                     allow_preview=False,
+                     columns=3,
+                     elem_id="gallery",
+                     show_share_button=False,
+                     height=400
+                 )
+
+                 custom_model = gr.Textbox(
+                     label="Or enter a custom HuggingFace FLUX LoRA",
+                     placeholder="e.g., username/lora-name",
+                     visible=False
+                 )
+                 custom_model_card = gr.HTML(visible=False)
+                 custom_model_button = gr.Button("Remove custom LoRA", visible=False)
+
+         with gr.Column(scale=5):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Editing Prompt",
+                     show_label=False,
+                     lines=1,
+                     max_lines=1,
+                     placeholder="opt - describe the person/subject, e.g. 'a man with glasses and a beard'",
+                     elem_id="prompt"
+                 )
+                 run_button = gr.Button("Generate", elem_id="run_button")
+
+             result = gr.Image(label="Generated Image", interactive=False)
+             reuse_button = gr.Button("Reuse this image", visible=False)
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 lora_scale = gr.Slider(
+                     label="LoRA Scale",
+                     minimum=0,
+                     maximum=2,
+                     step=0.1,
+                     value=1.5,
+                     info="Controls the strength of the LoRA effect"
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 guidance_scale = gr.Slider(
+                     label="Guidance Scale",
+                     minimum=1,
+                     maximum=10,
+                     step=0.1,
+                     value=2.5,
+                 )
+
+             prompt_title = gr.Markdown(
+                 value="### Click on a LoRA in the gallery to select it",
+                 visible=True,
+                 elem_id="selected_lora",
+             )
+
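+     # Note: custom_model, custom_model_card and custom_model_button are all created
+     # with visible=False, and no handler re-shows the textbox itself, so the custom
+     # LoRA entry path below is effectively disabled in the rendered UI.
+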
+     # Event handlers
+     custom_model.input(
+         fn=load_custom_lora,
+         inputs=[custom_model],
+         outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
+     )
+
+     custom_model_button.click(
+         fn=remove_custom_lora,
+         outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
+     )
+
+     gallery.select(
+         fn=update_selection,
+         inputs=[gr_flux_loras],
+         outputs=[prompt_title, prompt, selected_state, lora_scale],
+         show_progress=False
+     )
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer_with_lora_wrapper,
+         inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
+         outputs=[result, seed, reuse_button, lora_state]
+     )
+
+     reuse_button.click(
+         fn=lambda image: image,
+         inputs=[result],
+         outputs=[input_image]
+     )
+
+     # Initialize gallery
+     demo.load(
+         fn=classify_gallery,
+         inputs=[gr_flux_loras],
+         outputs=[gallery, gr_flux_loras]
+     )
+
+ demo.queue(default_concurrency_limit=None)
+ demo.launch()