AlekseyCalvin committed
Commit 36d4c5e · verified · 1 Parent(s): 5f00555

Create app.py

Files changed (1)
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
import os
import gradio as gr
import json
import logging
import torch
from PIL import Image
from os import path
from torchvision import transforms
from dataclasses import dataclass
import math
from typing import Callable
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPConfig, T5EncoderModel, T5Tokenizer
from diffusers.models.transformers import FluxTransformer2DModel
import copy
import random
import time
import safetensors.torch
from tqdm import tqdm
from safetensors.torch import load_file
from huggingface_hub import HfFileSystem, ModelCard
from huggingface_hub import login, hf_hub_download

hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

#torch.set_float32_matmul_precision("medium")

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)

# Initialize the base model
dtype = torch.bfloat16
base_model = "AlekseyCalvin/Flux-Krea-Blaze_byMintLab_fp8_Diffusers"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to("cuda")
#pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.float16).to("cuda")
torch.cuda.empty_cache()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Swap in LongCLIP (extended 248-token context) as the pipeline's CLIP tokenizer and text encoder
model_id = ("zer0int/LongCLIP-GmP-ViT-L-14")
config = CLIPConfig.from_pretrained(model_id)
config.text_config.max_position_embeddings = 248
clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config, ignore_mismatched_sizes=True)
clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=248)
pipe.tokenizer = clip_processor.tokenizer
pipe.text_encoder = clip_model.text_model
pipe.tokenizer_max_length = 248
pipe.text_encoder.dtype = torch.bfloat16
#pipe.text_encoder_2 = t5.text_model

MAX_SEED = 2**32-1

# Context manager that times a named activity and prints the elapsed seconds
class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


# Gallery selection handler: updates the prompt placeholder, the info text, and aspect-ratio presets
def update_selection(evt: gr.SelectData, width, height):
    selected_lora = loras[evt.index]
    new_placeholder = f"Prompt with activator word(s): '{selected_lora['trigger_word']}'! "
    lora_repo = selected_lora["repo"]
    lora_trigger = selected_lora['trigger_word']
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}). Prompt using: '{lora_trigger}'!"
    if "aspect" in selected_lora:
        if selected_lora["aspect"] == "portrait":
            width = 768
            height = 1024
        elif selected_lora["aspect"] == "landscape":
            width = 1024
            height = 768
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        width,
        height,
    )

@spaces.GPU(duration=80)
def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed)

    with calculateDuration("Generating image"):
        # Generate image
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    return image

def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora['trigger_word']

    # Load LoRA weights
    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        if "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

    # Randomize the seed if requested
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
    pipe.to("cpu")
    pipe.unload_lora_weights()
    return image, seed

run_lora.zerogpu = True

css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
'''
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> SOONfactory </h1>""",
        elem_id="title",
    )
    # Info blob stating what the app is running
    info_blob = gr.HTML(
        """<div id="info_blob"> Img. Manufactory Running On: the Flux Krea Blaze model (a fast open offshoot of FLUX). Nearly all of the LoRA adapters accessible via this space were trained by us in an extensive progression of inspired experiments and conceptual mini-projects. Check out our poetry translations at WWW.SILVERagePOETS.com. Find our music on SoundCloud @ AlekseyCalvin & YouTube @ SilverAgePoets / AlekseyCalvin! </div>"""
    )

    # Info blob with a usage tip
    info_blob = gr.HTML(
        """<div id="info_blob"> To reinforce/focus on selected fine-tuned LoRAs (Low-Rank Adapters), add their special "trigger" words/phrases to your prompts. </div>"""
    )
    selected_index = gr.State(None)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRA/Style & type prompt!")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    with gr.Row():
        with gr.Column(scale=3):
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Inventory",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )

        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")

    with gr.Row():
        with gr.Accordion("Advanced Settings", open=True):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=20, step=.1, value=2.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=9)

                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)

                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.5, step=0.01, value=1.0)

    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )

    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )

app.queue(default_concurrency_limit=2).launch(show_error=True)
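
app.py reads a "loras.json" file sitting next to it, but this commit only adds app.py itself and does not document that file's format. Below is a minimal sketch of what one entry is assumed to look like, with the keys inferred from how the script above accesses them ("image", "title", "repo", "trigger_word", plus the optional "weights" and "aspect"); every value is a placeholder, not an actual entry from this Space.

import json

# Hypothetical loras.json content -- keys mirror what app.py reads, values are illustrative only.
example_loras = [
    {
        "image": "thumbnails/example.png",            # thumbnail shown in the "LoRA Inventory" gallery
        "title": "Example LoRA",                      # gallery caption, also used in the timing log
        "repo": "your-username/example-flux-lora",    # repo id passed to pipe.load_lora_weights()
        "trigger_word": "example style",              # appended to every prompt for this adapter
        "weights": "example.safetensors",             # optional: specific weight file within the repo
        "aspect": "portrait",                         # optional: "portrait" or "landscape" size preset
    }
]

with open("loras.json", "w") as f:
    json.dump(example_loras, f, indent=2)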
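The LongCLIP swap near the top of app.py is what raises the usable prompt length from CLIP's default 77 tokens to 248 (both max_position_embeddings and tokenizer_max_length are set to 248). A standalone sketch for checking that limit, assuming only the same zer0int/LongCLIP-GmP-ViT-L-14 processor that app.py loads; run it outside the Space.

from transformers import CLIPProcessor

# Load the same LongCLIP processor used in app.py and tokenize a deliberately over-long prompt.
processor = CLIPProcessor.from_pretrained("zer0int/LongCLIP-GmP-ViT-L-14")
long_prompt = "an ornately detailed scene of many small interwoven figures, " * 40
tokens = processor.tokenizer(long_prompt, truncation=True, max_length=248, return_tensors="pt")
# With the extended limit the prompt is truncated at 248 tokens rather than at CLIP's default 77.
print(tokens.input_ids.shape[1])  # expected: 248 for this prompt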