jbilcke-hf HF Staff commited on
Commit
12c84ce
·
verified ·
1 Parent(s): 94c2023

Update utils/wan_wrapper.py

Browse files
Files changed (1) hide show
  1. utils/wan_wrapper.py +13 -4
utils/wan_wrapper.py CHANGED
@@ -22,12 +22,17 @@ class WanTextEncoder(torch.nn.Module):
22
  device=torch.device('cpu')
23
  ).eval().requires_grad_(False)
24
  self.text_encoder.load_state_dict(
25
- torch.load("/repository/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
 
 
26
  map_location='cpu', weights_only=False)
27
  )
28
 
29
  self.tokenizer = HuggingfaceTokenizer(
30
- name="/repository/Wan2.1-T2V-1.3B/google/umt5-xxl/", seq_len=512, clean='whitespace')
 
 
 
31
 
32
  @property
33
  def device(self):
@@ -66,7 +71,9 @@ class WanVAEWrapper(torch.nn.Module):
66
 
67
  # init model
68
  self.model = _video_vae(
69
- pretrained_path="/repository/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
 
 
70
  z_dim=16,
71
  ).eval().requires_grad_(False)
72
 
@@ -115,7 +122,9 @@ class WanVAEWrapper(torch.nn.Module):
115
  class WanDiffusionWrapper(torch.nn.Module):
116
  def __init__(
117
  self,
118
- model_name="Wan2.1-T2V-1.3B",
 
 
119
  timestep_shift=8.0,
120
  is_causal=False,
121
  local_attn_size=-1,
 
22
  device=torch.device('cpu')
23
  ).eval().requires_grad_(False)
24
  self.text_encoder.load_state_dict(
25
+ # I should have called the folder "Wan2.1-T2V-1.3B" instead of "wan2.1"
26
+ # torch.load("/repository/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
27
+ torch.load("/repository/wan2.1/models_t5_umt5-xxl-enc-bf16.pth",
28
  map_location='cpu', weights_only=False)
29
  )
30
 
31
  self.tokenizer = HuggingfaceTokenizer(
32
+ # I should have called the folder "Wan2.1-T2V-1.3B" instead of "wan2.1"
33
+ #name="/repository/Wan2.1-T2V-1.3B/google/umt5-xxl/",
34
+ name="/repository/wan2.1/google/umt5-xxl/",
35
+ seq_len=512, clean='whitespace')
36
 
37
  @property
38
  def device(self):
 
71
 
72
  # init model
73
  self.model = _video_vae(
74
+ # I should have called the folder "Wan2.1-T2V-1.3B" instead of "wan2.1"
75
+ #pretrained_path="/repository/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
76
+ pretrained_path="/repository/wan2.1/Wan2.1_VAE.pth",
77
  z_dim=16,
78
  ).eval().requires_grad_(False)
79
 
 
122
  class WanDiffusionWrapper(torch.nn.Module):
123
  def __init__(
124
  self,
125
+ # I should have called the folder "Wan2.1-T2V-1.3B" instead of "wan2.1"
126
+ #model_name="Wan2.1-T2V-1.3B",
127
+ model_name="wan2.1",
128
  timestep_shift=8.0,
129
  is_causal=False,
130
  local_attn_size=-1,