lavoies committed · verified
Commit 2ee7a48 · 1 Parent(s): 65a6d9f

Upload model

Files changed (3)
  1. config.json +4 -0
  2. configuration_llip.py +12 -0
  3. modeling_llip.py +364 -0
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "LlipModel"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_llip.LlipConfig",
+    "AutoModel": "modeling_llip.LlipModel"
+  },
   "init_logit_bias": -10,
   "initializer_factor": 1.0,
   "logit_scale_init_value": 2.6592,
configuration_llip.py ADDED
@@ -0,0 +1,12 @@
+from transformers import CLIPConfig
+
+
+class LlipConfig(CLIPConfig):
+    model_type = "llip"
+
+    def __init__(self, use_norm=True, ncls=64, num_heads=8, temp=1.0, **kwargs):
+        super().__init__(**kwargs)
+        self.use_norm = use_norm
+        self.ncls = ncls
+        self.num_heads = num_heads
+        self.temp = temp
+        # TODO: Get the vision_config parameters
+        # (modeling_llip.py also expects `ncls` and `pass_all_tokens` on vision_config)
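
LlipConfig only layers the Llip-specific fields (use_norm, ncls, num_heads, temp) on top of a standard CLIPConfig; everything else (text_config, vision_config, projection_dim, ...) is inherited via **kwargs. A quick sketch, assuming the stock CLIPConfig defaults:

    from configuration_llip import LlipConfig

    cfg = LlipConfig(ncls=64, num_heads=8, temp=1.0, use_norm=True, projection_dim=512)
    print(cfg.model_type)                # "llip"
    print(cfg.num_heads, cfg.temp)       # 8 1.0
    print(cfg.vision_config.patch_size)  # inherited CLIP vision default (32)
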
modeling_llip.py ADDED
@@ -0,0 +1,364 @@
+"""Llip: a CLIP variant that produces text-conditioned image embeddings.
+
+Instead of pooling a single [CLS] token, the vision tower prepends `ncls` learned
+query tokens and keeps all of them as the pooled output. These tokens are projected
+into per-token keys K and values V, the text encoder produces a query Q, and a small
+cross-attention head (LlipPred) mixes the visual tokens with the text query to obtain
+an image embedding that depends on the caption. Training uses a sigmoid (SigLIP-style)
+contrastive loss.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import (
+    CLIPModel,
+    PretrainedConfig,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+from transformers.utils import ModelOutput
+
+from configuration_llip import LlipConfig
+
+
+@dataclass
+class LlipOutput(ModelOutput):
+    loss: Optional[torch.Tensor] = None
+    K: Optional[torch.Tensor] = None
+    V: Optional[torch.Tensor] = None
+    Q: Optional[torch.Tensor] = None
+    image_embeds: Optional[torch.Tensor] = None
+    text_embeds: Optional[torch.Tensor] = None
+    logit_scale: Optional[torch.Tensor] = None
+    logit_bias: Optional[torch.Tensor] = None
+
+
+class LlipPred(torch.nn.Module):
+    def __init__(self, embed_dim):
+        super().__init__()
+        scale_out = embed_dim**-0.5
+        self.out_proj = nn.Parameter(scale_out * torch.randn(embed_dim, embed_dim))
+
+    def cross_attention(self, K, Q, V, weight_scale, out_proj):
+        # K, V: [n_images, num_heads, ncls, head_dim]; Q: [n_texts, num_heads, head_dim]
+        attn = (torch.einsum("vhnd,thd->vthn", K, Q) / weight_scale).softmax(-1)
+        # zv: [n_images, n_texts, embed_dim] -- one visual embedding per (image, text) pair
+        zv = torch.einsum("vthn,vhnd->vthd", attn, V).reshape(
+            K.shape[0], Q.shape[0], -1
+        )
+        zv = zv @ out_proj
+        return zv
+
+    def forward(self, K, Q, V, weight_scale):
+        out = self.cross_attention(K, Q, V, weight_scale, self.out_proj)
+        return out
+
+
+def torch_int(x):
+    """
+    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
+    """
+    import torch
+
+    return (
+        x.to(torch.int64)
+        if torch.jit.is_tracing() and isinstance(x, torch.Tensor)
+        else int(x)
+    )
+
+
+class LlipVisionTransformer(CLIPVisionTransformer):
+    def __init__(self, config):
+        super().__init__(config)
+        self.embeddings = LlipVisionEmbeddings(config)
+
+    def forward(
+        self,
+        pixel_values=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        interpolate_pos_encoding=False,
+    ):
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        hidden_states = self.embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+
+        last_hidden_state = encoder_outputs.last_hidden_state
+        # The pooled output keeps all `ncls` learned query tokens (prepended by
+        # LlipVisionEmbeddings), not a single [CLS] token as in CLIP.
+        pooled_output = last_hidden_state[:, : self.config.ncls, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class LlipVisionEmbeddings(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.ncls = config.ncls
+
+        # A single learned vector shared by all `ncls` query tokens; the tokens are
+        # distinguished by their position embeddings.
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + self.ncls
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def interpolate_pos_encoding(
+        self, embeddings: torch.Tensor, height: int, width: int
+    ) -> torch.Tensor:
+        """
+        Interpolates the pre-trained position encodings so the model can be used on higher
+        resolution images. This method is also adapted to support torch.jit tracing.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+        # The first `ncls` positions belong to the learned query tokens; only the patch
+        # positions are interpolated.
+        num_patches = embeddings.shape[1] - self.ncls
+        position_embedding = self.position_embedding.weight.unsqueeze(0)
+        num_positions = position_embedding.shape[1] - self.ncls
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if (
+            not torch.jit.is_tracing()
+            and num_patches == num_positions
+            and height == width
+        ):
+            return self.position_embedding(self.position_ids)
+
+        class_pos_embed = position_embedding[:, : self.ncls]
+        patch_pos_embed = position_embedding[:, self.ncls :]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(
+            1, sqrt_num_positions, sqrt_num_positions, dim
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(
+        self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False
+    ) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (
+            height != self.image_size or width != self.image_size
+        ):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
+            )
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [batch, embed_dim, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        # Prepend the `ncls` query tokens to the patch sequence.
+        class_embeds = self.class_embedding.expand(batch_size, self.ncls, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(
+                embeddings, height, width
+            )
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
+        return embeddings
+
+
+class LlipModel(CLIPModel):
+    config_class = LlipConfig
+
+    def __init__(self, *args, **kwargs):
+        # Build the standard CLIPModel first, then swap in the Llip-specific pieces:
+        # a vision tower with `ncls` learned query tokens, per-token K/V projections
+        # for the image side, and a Q projection for the text side.
+        super().__init__(*args, **kwargs)
+
+        self.visual_projection = None
+        # self.config.vision_config is broken.
+        self.vision_model = LlipVisionTransformer(self.config.vision_config)
+        ncls = self.config.vision_config.ncls
+        embed_dim = self.config.projection_dim
+        self.num_heads = self.config.num_heads
+
+        scale_visual = self.config.vision_config.hidden_size**-0.5
+        # `pass_all_tokens` may be absent from the vision config; default to projecting
+        # only the `ncls` query tokens, which matches get_image_features() below.
+        if getattr(self.config.vision_config, "pass_all_tokens", False):
+            num_proj = self.vision_model.embeddings.position_embedding.weight.size(0)
+        else:
+            num_proj = ncls
+        self.v_proj = nn.Parameter(
+            scale_visual
+            * torch.randn(num_proj, self.config.vision_config.hidden_size, embed_dim)
+        )
+        self.k_proj = nn.Parameter(
+            scale_visual
+            * torch.randn(num_proj, self.config.vision_config.hidden_size, embed_dim)
+        )
+
+        scale_text = self.config.text_config.hidden_size**-0.5
+        self.q_proj = nn.Parameter(
+            scale_text * torch.randn(self.config.text_config.hidden_size, embed_dim)
+        )
+        self.logit_bias = -10
+
+        if self.config.use_norm:
+            self.K_norm = nn.LayerNorm(embed_dim)
+            self.Q_norm = nn.LayerNorm(embed_dim)
+            self.V_norm = nn.LayerNorm(embed_dim)
+        else:
+            self.K_norm = nn.Identity()
+            self.Q_norm = nn.Identity()
+            self.V_norm = nn.Identity()
+
+        self.pred = LlipPred(embed_dim)
+
+    def get_image_features(self, image):
+        """
+        Returns K, V
+        """
+        # pooler_output: [B, ncls, vision_hidden]; k_proj / v_proj: [ncls, vision_hidden, embed_dim]
+        h = self.vision_model(image).pooler_output
+        K = h.transpose(0, 1) @ self.k_proj
+        V = h.transpose(0, 1) @ self.v_proj
+        N, B, C = K.shape
+        K = self.K_norm(K)
+        V = self.V_norm(V)
+        K = K.reshape(N, B, self.num_heads, C // self.num_heads).permute(
+            1, 2, 0, 3
+        )  # [B, num_heads, N, D]
+        V = V.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
+        return K, V
+
+    def get_text_features(self, text):
+        """
+        Returns Q, zt
+        """
+        # h = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
+
+        # h = h + self.positional_embedding
+        # h = h.permute(1, 0, 2)  # NLD -> LND
+        # h = self.text_model(h, attn_mask=self.attn_mask).last_hidden_state
+        # h = h.permute(1, 0, 2)  # LND -> NLD
+        # h = self.ln_final(h)
+
+        # # x.shape = [batch_size, n_ctx, transformer.width]
+        # # take features from the eot embedding (eot_token is the highest number in each sequence)
+        # h = h[torch.arange(h.shape[0]), text.argmax(dim=-1)]
+        h = self.text_model(text).pooler_output
+
+        Q = h @ self.q_proj
+        B, C = Q.shape
+        Q = self.Q_norm(Q)
+        Q = Q.reshape(B, self.num_heads, C // self.num_heads)
+        zt = self.text_projection(h)
+        return Q, zt
+
+    def forward(
+        self,
+        input_ids,
+        pixel_values,
+        clamp_logit_scale_to=None,
+        compute_image_embeds=False,
+        compute_loss=False,
+        return_dict=False,
+    ):
+        """
+        Returns a LlipOutput with K, V, Q, text_embeds, optional image_embeds and loss,
+        logit_scale and logit_bias.
+        """
+        K, V = self.get_image_features(pixel_values)
+        Q, zt = self.get_text_features(input_ids)
+
+        if clamp_logit_scale_to is not None:
+            with torch.no_grad():
+                self.logit_scale.data.clamp_(0, clamp_logit_scale_to)
+
+        loss = None
+        image_embeds = None
+        if compute_image_embeds:
+            # image_embeds: [n_images, n_texts, embed_dim] -- one embedding per (image, text) pair
+            image_embeds = self.pred(K, Q, V, self.config.temp)
+        if compute_loss:
+            assert compute_image_embeds
+            # Sigmoid (SigLIP-style) contrastive loss over cosine similarities, with
+            # matched pairs on the diagonal labelled +1 and all other pairs -1.
+            normalized_image_embeds = torch.nn.functional.normalize(
+                image_embeds, dim=-1
+            )
+            normalized_text_embeds = torch.nn.functional.normalize(zt, dim=-1)
+            logits = self.logit_scale.exp() * (
+                normalized_text_embeds[None] * normalized_image_embeds
+            ).sum(-1)
+            logits += self.logit_bias
+            labels = -torch.ones(
+                (len(logits), len(logits)), device=logits.device, dtype=logits.dtype
+            )
+            labels = (
+                2 * torch.eye(len(logits), device=logits.device, dtype=logits.dtype)
+                + labels
+            )
+            loss = -torch.nn.functional.logsigmoid(labels * logits).sum() / len(
+                image_embeds
+            )
+
+        return LlipOutput(
+            loss=loss,
+            K=K,
+            V=V,
+            Q=Q,
+            text_embeds=zt,
+            image_embeds=image_embeds,
+            logit_scale=self.logit_scale.exp(),
+            logit_bias=self.logit_bias,
+        )
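
End to end, the forward pass produces one image embedding per (image, caption) pair rather than a single embedding per image. A rough shape-checking sketch, assuming the uploaded config carries the fields the code reads (ncls on the vision config, num_heads, temp) and reusing the same placeholder repo id as above:

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained("lavoies/llip", trust_remote_code=True)  # hypothetical repo id
    model.eval()

    B = 2
    size = model.config.vision_config.image_size
    pixel_values = torch.randn(B, 3, size, size)
    input_ids = torch.randint(0, model.config.text_config.vocab_size, (B, 16))

    with torch.no_grad():
        out = model(input_ids, pixel_values, compute_image_embeds=True, compute_loss=True)

    print(out.K.shape)             # [B, num_heads, ncls, projection_dim // num_heads]
    print(out.Q.shape)             # [B, num_heads, projection_dim // num_heads]
    print(out.image_embeds.shape)  # [B, B, projection_dim] -- image v conditioned on caption t
    print(out.loss)                # scalar sigmoid contrastive loss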