gokaygokay commited on
Commit
40c933d
·
verified ·
1 Parent(s): bae69a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -13,6 +13,7 @@ from kolors.models.modeling_chatglm import ChatGLMModel
13
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
14
  from diffusers import UNet2DConditionModel, AutoencoderKL
15
  from diffusers import EulerDiscreteScheduler
 
16
 
17
  # Initialize models
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,6 +22,8 @@ dtype = torch.float16
21
  # Download Kolors model
22
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
23
 
 
 
24
  # Load Kolors models
25
  text_encoder = ChatGLMModel.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'), torch_dtype=dtype).to(device)
26
  tokenizer = ChatGLMTokenizer.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'))
 
13
  from kolors.models.tokenization_chatglm import ChatGLMTokenizer
14
  from diffusers import UNet2DConditionModel, AutoencoderKL
15
  from diffusers import EulerDiscreteScheduler
16
+ import subprocess
17
 
18
  # Initialize models
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
22
  # Download Kolors model
23
  ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
24
 
25
+ subprocess.run('pip install flash-attn --no-build-isolation', shell=True)
26
+
27
  # Load Kolors models
28
  text_encoder = ChatGLMModel.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'), torch_dtype=dtype).to(device)
29
  tokenizer = ChatGLMTokenizer.from_pretrained(os.path.join(ckpt_dir, 'text_encoder'))