Xsong123 committed
Commit ff12ba7 · verified · 1 Parent(s): 4d0c4ea

Update app.py

Files changed (1):
  1. app.py +56 -40
app.py CHANGED

@@ -1,6 +1,6 @@
 import gradio as gr
 import numpy as np
-import spaces
+import spaces  # This is a special module for Hugging Face Spaces, not needed for local execution
 import torch
 import random
 import json
@@ -13,10 +13,14 @@ from safetensors.torch import load_file
 import requests
 import re
 
-# Load Kontext model
+# Load Kontext model from a local path
 MAX_SEED = np.iinfo(np.int32).max
 
-pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
+# Use the local path for the base model, as in test.py
+pipe = FluxKontextPipeline.from_pretrained(
+    "/hpc2hdd/home/sfei285/Project/Editing/FLUX.1-Kontext-dev",
+    torch_dtype=torch.bfloat16
+).to("cuda")
 
 # Load LoRA data from our custom JSON file
 with open("kontext_loras.json", "r") as file:
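Worth noting for portability: `from_pretrained` accepts either a Hub repo id or a local directory, so the hard-coded HPC path above could be made configurable. A minimal sketch, assuming a hypothetical `KONTEXT_MODEL_PATH` environment variable that is not part of this commit:

```python
import os
import torch
from diffusers import FluxKontextPipeline

# KONTEXT_MODEL_PATH is a hypothetical override, not defined by this commit;
# fall back to the Hub repo id when no local checkout is configured.
model_path = os.environ.get("KONTEXT_MODEL_PATH", "black-forest-labs/FLUX.1-Kontext-dev")
pipe = FluxKontextPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
```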
@@ -28,12 +32,12 @@ with open("kontext_loras.json", "r") as file:
         "title": item["title"],
         "repo": item["repo"],
         "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
+        "prompt": item.get("prompt", f"Turn this image into {item['title']} style."),
         # The following keys are kept for compatibility with the original demo structure,
         # but our simplified logic doesn't heavily rely on them.
-        "trigger_word": item.get("trigger_word", ""),
         "lora_type": item.get("lora_type", "flux"),
         "lora_scale_config": item.get("lora_scale", 1.0),  # Default scale set to 1.0
-        "prompt_placeholder": item.get("prompt_placeholder", "Describe the subject..."),
+        "prompt_placeholder": item.get("prompt_placeholder", "You can edit the prompt here..."),
     }
     for item in data
 ]
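For reference, the comprehension above implies the shape of each `kontext_loras.json` entry: only `title` and `repo` are strictly required; everything else has a default. A hypothetical entry, with all values invented for illustration (the `image` key is assumed from the reworked gallery initializer at the bottom of the diff):

```python
# Hypothetical kontext_loras.json entry, shown as the Python object json.load would return.
example_entry = {
    "title": "Watercolor",                                 # required; also seeds the default prompt
    "repo": "Kontext-Style/watercolor_lora",               # required; hypothetical repo id
    "weights": "pytorch_lora_weights.safetensors",         # optional; this is the default
    "prompt": "Turn this image into Watercolor style.",    # optional; defaulted from "title"
    "lora_scale": 1.0,                                     # optional; exposed as "lora_scale_config"
    "image": "https://example.com/watercolor_preview.png", # preview used by the gallery
}
```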
@@ -44,30 +48,30 @@ def update_selection(selected_state: gr.SelectData, flux_loras):
     if selected_state.index >= len(flux_loras):
         return "### No LoRA selected", gr.update(), None, gr.update()
 
-    lora_repo = flux_loras[selected_state.index]["repo"]
+    selected_lora = flux_loras[selected_state.index]
+    lora_repo = selected_lora["repo"]
+    default_prompt = selected_lora.get("prompt")
 
     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
-    config_placeholder = flux_loras[selected_state.index]["prompt_placeholder"]
 
-    optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.0)
-    print("Selected Style: ", flux_loras[selected_state.index]['title'])
+    optimal_scale = selected_lora.get("lora_scale_config", 1.0)
+    print("Selected Style: ", selected_lora['title'])
     print("Optimal Scale: ", optimal_scale)
-    return updated_text, gr.update(placeholder=config_placeholder), selected_state.index, optimal_scale
+    return updated_text, gr.update(value=default_prompt), selected_state.index, optimal_scale
 
 # This wrapper is kept for compatibility with the Gradio event triggers
-def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=0, guidance_scale=2.5, num_inference_steps=28, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Wrapper function to handle state serialization"""
     # The 'custom_lora' and 'lora_state' arguments are no longer used but kept in the signature
-    return infer_with_lora(input_image, prompt, selected_index, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
+    return infer_with_lora(input_image, prompt, selected_index, seed, guidance_scale, num_inference_steps, lora_scale, flux_loras, progress)
 
-@spaces.GPU
-def infer_with_lora(input_image, prompt, selected_index, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+@spaces.GPU  # This decorator is only for Hugging Face Spaces hardware, not needed for local execution
+def infer_with_lora(input_image, prompt, selected_index, seed=0, guidance_scale=2.5, num_inference_steps=28, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Generate image with selected LoRA"""
     global pipe
 
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
+    # The seed is now always taken directly from the input; randomization has been removed.
+
     # Unload any previous LoRA to ensure a clean state
     if "selected_lora" in pipe.get_active_adapters():
         pipe.unload_lora_weights()
@@ -89,20 +93,12 @@ def infer_with_lora(input_image, prompt, selected_index, seed=42, randomize_seed
             pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
             print(f"Loaded {lora_to_use['repo']} with scale {lora_scale}")
 
-            # Simplified and direct prompt construction
-            style_name = lora_to_use['title']
-            if prompt:
-                final_prompt = f"Turn this image of {prompt} into {style_name} style."
-            else:
-                final_prompt = f"Turn this image into {style_name} style."
-            print(f"Using prompt: {final_prompt}")
-
         except Exception as e:
             print(f"Error loading LoRA: {e}")
-            final_prompt = prompt  # Fallback to user prompt if LoRA fails
-    else:
-        # No LoRA selected, just use the original prompt
-        final_prompt = prompt
+
+    # Use the prompt from the textbox directly.
+    final_prompt = prompt
+    print(f"Using prompt: {final_prompt}")
 
     input_image = input_image.convert("RGB")
 
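The unload-then-load sequence above is the usual diffusers adapter workflow. Isolated as a minimal sketch (the function name and argument names are illustrative, not part of the commit):

```python
def swap_lora(pipe, repo_id: str, weight_name: str, scale: float) -> None:
    """Drop any active 'selected_lora' adapter and load a new one at the given scale."""
    if "selected_lora" in pipe.get_active_adapters():
        pipe.unload_lora_weights()  # clean state before loading the next style
    pipe.load_lora_weights(repo_id, weight_name=weight_name, adapter_name="selected_lora")
    pipe.set_adapters(["selected_lora"], adapter_weights=[scale])
```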
@@ -113,14 +109,17 @@ def infer_with_lora(input_image, prompt, selected_index, seed=42, randomize_seed
             height=input_image.size[1],
             prompt=final_prompt,
             guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
             generator=torch.Generator().manual_seed(seed)
         ).images[0]
 
-        return image, seed, gr.update(visible=True), lora_scale
+        # The seed value is no longer returned, as it's not being changed.
+        return image, gr.update(visible=True), lora_scale
 
     except Exception as e:
         print(f"Error during inference: {e}")
-        return None, seed, gr.update(visible=False), lora_scale
+        # Return an error state for all outputs
+        return None, gr.update(visible=False), lora_scale
 
 # CSS styling
 css = """
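With the `randomize_seed` checkbox gone, reproducibility rests entirely on `torch.Generator().manual_seed(seed)`: the same seed, image, prompt, and settings should give the same output. A sketch of the behavior this commit removes, should it ever be wanted again:

```python
import random

# Pattern removed by this commit: optionally draw a fresh seed before generating.
if randomize_seed:  # 'randomize_seed' was the checkbox this commit deletes
    seed = random.randint(0, MAX_SEED)
generator = torch.Generator().manual_seed(seed)  # deterministic for a fixed seed
```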
@@ -159,6 +158,10 @@ css = """
 #gallery{
     overflow: scroll !important
 }
+/* Custom CSS to ensure the input image is fully visible */
+#input_image_display div[data-testid="image"] img {
+    object-fit: contain !important;
+}
 """
 
 # Create Gradio interface
@@ -169,7 +172,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
         """<h1><img src="https://huggingface.co/spaces/kontext-community/FLUX.1-Kontext-portrait/resolve/main/dora_kontext.png" alt="LoRA"> Kontext-Style LoRA Explorer</h1>""",
         elem_id="title",
     )
-    gr.Markdown("A demo for the style LoRAs from the [Kontext-Style Collection](https://huggingface.co/Kontext-Style) 🤗")
+    gr.Markdown("A demo for the style LoRAs from the [Kontext-Style](https://huggingface.co/Kontext-Style) 🤗")
 
     selected_state = gr.State(value=None)
     # The following states are no longer used by the simplified logic but kept for component structure
@@ -179,15 +182,20 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
     with gr.Row(elem_id="main_app"):
         with gr.Column(scale=4, elem_id="box_column"):
             with gr.Group(elem_id="gallery_box"):
-                input_image = gr.Image(label="Upload a picture of yourself", type="pil", height=300)
-                portrait_mode = gr.Checkbox(label="portrait mode", value=True)
+                input_image = gr.Image(
+                    label="Upload a picture of yourself",
+                    type="pil",
+                    height=300,
+                    elem_id="input_image_display"
+                )
                 gallery = gr.Gallery(
                     label="Pick a LoRA",
                     allow_preview=False,
                     columns=3,
                     elem_id="gallery",
                     show_share_button=False,
-                    height=400
+                    height=400,
+                    object_fit="contain"
                 )
 
                 custom_model = gr.Textbox(
@@ -219,7 +227,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
                         minimum=0,
                         maximum=2,
                         step=0.1,
-                        value=1.5,
+                        value=1.0,
                         info="Controls the strength of the LoRA effect"
                     )
                     seed = gr.Slider(
@@ -229,7 +237,6 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
                         step=1,
                         value=0,
                     )
-                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                     guidance_scale = gr.Slider(
                         label="Guidance Scale",
                         minimum=1,
@@ -237,6 +244,14 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
                         step=0.1,
                         value=2.5,
                     )
+                    num_inference_steps = gr.Slider(
+                        label="Timesteps",
+                        minimum=1,
+                        maximum=100,
+                        step=1,
+                        value=28,
+                        info="Number of inference steps"
+                    )
 
             prompt_title = gr.Markdown(
                 value="### Click on a LoRA in the gallery to select it",
@@ -257,8 +272,8 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer_with_lora_wrapper,
-        inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
-        outputs=[result, seed, reuse_button, lora_state]
+        inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, guidance_scale, num_inference_steps, lora_scale, gr_flux_loras],
+        outputs=[result, reuse_button, lora_state]
     )
 
     reuse_button.click(
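A detail that makes edits like this one easy to get wrong: Gradio passes `inputs` to `fn` positionally, so the list must stay in lockstep with the wrapper's signature, which is why the commit changes both together. The same pattern reduced to a sketch (the handler and component names here are placeholders):

```python
# 'inputs' maps one-for-one onto the handler's positional parameters.
def handler(image, text, scale):
    ...

gr.on(
    triggers=[run_button.click, prompt.submit],
    fn=handler,
    inputs=[input_image, prompt, lora_scale],  # -> (image, text, scale)
    outputs=[result],
)
```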
@@ -269,7 +284,8 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
 
     # Initialize gallery
     demo.load(
-        fn=lambda: (flux_loras_raw, flux_loras_raw),
+        fn=lambda loras: ([(item["image"], item["title"]) for item in loras], loras),
+        inputs=[gr_flux_loras],
         outputs=[gallery, gr_flux_loras]
     )
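The reworked initializer builds the gallery payload explicitly: `gr.Gallery` accepts a list of `(image, caption)` pairs, while the full records pass through unchanged into the `gr_flux_loras` state. The same lambda written out as a named function (the name is illustrative):

```python
def init_gallery(loras):
    # The gallery wants (image, caption) pairs; the state keeps the complete LoRA records.
    gallery_items = [(item["image"], item["title"]) for item in loras]
    return gallery_items, loras
```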
 
 
291