DarthReca committed
Commit 6cd35b4 · verified · 1 Parent(s): 1ead1df

Upload 4 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ spherical_armonics.py filter=lfs diff=lfs merge=lfs -text
location_encoder.py ADDED
@@ -0,0 +1,158 @@
+ # Copyright (c) Microsoft Corporation.
+
+ import math
+
+ import torch
+ from einops import rearrange
+ from torch import nn
+ from torch.nn import functional as F
+
+ from .positional_encoding import SphericalHarmonics
+
+
+ class LocationEncoder(nn.Module):
+     def __init__(
+         self,
+         dim_hidden: int,
+         num_layers: int,
+         dim_out: int,
+         legendre_polys: int = 10,
+     ):
+         super().__init__()
+         self.posenc = SphericalHarmonics(legendre_polys=legendre_polys)
+         self.nnet = SirenNet(
+             dim_in=self.posenc.embedding_dim,
+             dim_hidden=dim_hidden,
+             num_layers=num_layers,
+             dim_out=dim_out,
+         )
+
+     def forward(self, x):
+         x = self.posenc(x)
+         return self.nnet(x)
+
+
+ class SirenNet(nn.Module):
+     """Sinusoidal Representation Network (SIREN)"""
+
+     def __init__(
+         self,
+         dim_in,
+         dim_hidden,
+         dim_out,
+         num_layers,
+         w0=1.0,
+         w0_initial=30.0,
+         use_bias=True,
+         final_activation=None,
+         degreeinput=False,
+         dropout=True,
+     ):
+         super().__init__()
+         self.num_layers = num_layers
+         self.dim_hidden = dim_hidden
+         self.degreeinput = degreeinput
+
+         self.layers = nn.ModuleList([])
+         for ind in range(num_layers):
+             is_first = ind == 0
+             layer_w0 = w0_initial if is_first else w0
+             layer_dim_in = dim_in if is_first else dim_hidden
+
+             self.layers.append(
+                 Siren(
+                     dim_in=layer_dim_in,
+                     dim_out=dim_hidden,
+                     w0=layer_w0,
+                     use_bias=use_bias,
+                     is_first=is_first,
+                     dropout=dropout,
+                 )
+             )
+
+         final_activation = (
+             nn.Identity() if not exists(final_activation) else final_activation
+         )
+         self.last_layer = Siren(
+             dim_in=dim_hidden,
+             dim_out=dim_out,
+             w0=w0,
+             use_bias=use_bias,
+             activation=final_activation,
+             dropout=False,
+         )
+
+     def forward(self, x, mods=None):
+         # normalize degree inputs into a -pi to pi range
+         if self.degreeinput:
+             x = torch.deg2rad(x) - torch.pi
+
+         mods = cast_tuple(mods, self.num_layers)
+
+         for layer, mod in zip(self.layers, mods):
+             x = layer(x)
+
+             if exists(mod):
+                 x *= rearrange(mod, "d -> () d")
+
+         return self.last_layer(x)
+
+
+ class Sine(nn.Module):
+     def __init__(self, w0=1.0):
+         super().__init__()
+         self.w0 = w0
+
+     def forward(self, x):
+         return torch.sin(self.w0 * x)
+
+
+ class Siren(nn.Module):
+     def __init__(
+         self,
+         dim_in,
+         dim_out,
+         w0=1.0,
+         c=6.0,
+         is_first=False,
+         use_bias=True,
+         activation=None,
+         dropout=False,
+     ):
+         super().__init__()
+         self.dim_in = dim_in
+         self.is_first = is_first
+         self.dim_out = dim_out
+         self.dropout = dropout
+
+         weight = torch.zeros(dim_out, dim_in)
+         bias = torch.zeros(dim_out) if use_bias else None
+         self.init_(weight, bias, c=c, w0=w0)
+
+         self.weight = nn.Parameter(weight)
+         self.bias = nn.Parameter(bias) if use_bias else None
+         self.activation = Sine(w0) if activation is None else activation
+
+     def init_(self, weight, bias, c, w0):
+         dim = self.dim_in
+
+         w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
+         weight.uniform_(-w_std, w_std)
+
+         if exists(bias):
+             bias.uniform_(-w_std, w_std)
+
+     def forward(self, x):
+         out = F.linear(x, self.weight, self.bias)
+         if self.dropout:
+             out = F.dropout(out, training=self.training)
+         out = self.activation(out)
+         return out
+
+
+ def exists(val):
+     return val is not None
+
+
+ def cast_tuple(val, repeat=1):
+     return val if isinstance(val, tuple) else ((val,) * repeat)
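
For reference, a minimal usage sketch of the encoder above (not part of the commit). It assumes the uploaded files are importable as a package, so the relative import of positional_encoding resolves; the package name in the import is hypothetical.

import torch

# Hypothetical package path; adjust to wherever these modules live.
from closp_model.location_encoder import LocationEncoder

# dim_hidden=512, num_layers=2, dim_out=256, legendre_polys=10 mirrors the
# hardcoded instantiation in modeling_closp.py below; eval() disables the
# dropout inside the SIREN layers.
encoder = LocationEncoder(dim_hidden=512, num_layers=2, dim_out=256, legendre_polys=10).eval()

# Input is a (batch, 2) tensor of (longitude, latitude) in degrees.
lonlat = torch.tensor([[9.19, 45.46], [-74.01, 40.71]])
embeddings = encoder(lonlat)
print(embeddings.shape)  # torch.Size([2, 256]); the posenc feeds 10*10=100 features into the SIREN
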
modeling_closp.py ADDED
@@ -0,0 +1,202 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from timm import create_model
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.utils import ModelOutput
+
+ from .location_encoder import LocationEncoder
+
+
+ class CLOSPConfig(PretrainedConfig):
+     """
+     Configuration class for CLOSPModel.
+
+     This class stores the configuration of a CLOSPModel, which is used to instantiate the model
+     according to the specified parameters.
+     """
+
+     model_type = "closp"
+
+     def __init__(
+         self,
+         # Vision model parameters
+         vision_model_key: str = "vit-s",
+         s1_embedding_dim: int = 384,
+         s2_embedding_dim: int = 384,
+         s1_head_dim: int = 0,
+         s2_head_dim: int = 0,
+         # Text model parameters
+         text_model_name_or_path: str = "distilbert-base-uncased",
+         # Location encoder parameters (optional)
+         use_location_encoder: bool = True,
+         location_embedding_dim: int = 512,
+         # General model parameters
+         projection_dim: int = 768,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.vision_model_key = vision_model_key
+         self.s1_embedding_dim = s1_embedding_dim
+         self.s2_embedding_dim = s2_embedding_dim
+         self.text_model_name_or_path = text_model_name_or_path
+         self.use_location_encoder = use_location_encoder
+         self.location_embedding_dim = location_embedding_dim
+         self.projection_dim = projection_dim
+         self.s1_head_dim = s1_head_dim
+         self.s2_head_dim = s2_head_dim
+
+
+ # --- Structured Model Output ---
+ @dataclass
+ class CLOSPOutput(ModelOutput):
+     """
+     Base class for CLOSP model's outputs.
+     """
+
+     loss: torch.FloatTensor = None
+     logits_per_image: torch.FloatTensor = None
+     logits_per_text: torch.FloatTensor = None
+     logits_per_loc_img: torch.FloatTensor = None
+     logits_per_img_loc: torch.FloatTensor = None
+     image_embeds: torch.FloatTensor = None
+     text_embeds: torch.FloatTensor = None
+     location_embeds: torch.FloatTensor = None
+
+
+ class CLOSPModel(PreTrainedModel):
+     config_class = CLOSPConfig
+
+     def __init__(self, config: CLOSPConfig):
+         super().__init__(config)
+         # --- Vision Encoders ---
+         self.s1_encoder = create_model(
+             config.vision_model_key,
+             in_chans=2,
+             num_classes=config.s1_head_dim,
+             pretrained=False,
+         )
+         self.s2_encoder = create_model(
+             config.vision_model_key,
+             in_chans=13,
+             num_classes=config.s2_head_dim,
+             pretrained=False,
+         )
+         self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
+         self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)
+
+         # --- Text Encoder ---
+         self.text_model = AutoModel.from_config(
+             AutoConfig.from_pretrained(config.text_model_name_or_path)
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
+
+         # --- Location Encoder ---
+         if config.use_location_encoder:
+             self.location_encoder = LocationEncoder(512, 2, 256, 10)
+             self.location_projection = nn.Linear(
+                 config.location_embedding_dim, config.projection_dim
+             )
+
+     def tokenize_text(self, text: str):
+         """Tokenizes input text using the model's tokenizer."""
+         return self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+             return_tensors="pt",
+         )
+
+     def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
+         """Encodes an image tensor into features."""
+         image = image.float()
+         if image.shape[1] == 2:  # Sentinel-1
+             image_features = self.s1_projection(self.s1_encoder(image))
+         else:  # Sentinel-2
+             image_features = self.s2_projection(self.s2_encoder(image))
+
+         return F.normalize(image_features, p=2, dim=-1)
+
+     def get_text_features(
+         self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+     ) -> torch.Tensor:
+         """Encodes text tokens into features."""
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_hidden_states=True,
+         )
+         text_features = text_outputs.last_hidden_state[:, 0, :]
+         return F.normalize(text_features, p=2, dim=-1)
+
+     def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
+         """Encodes coordinates into features."""
+         if not self.config.use_location_encoder:
+             raise ValueError(
+                 "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
+             )
+         location_features = self.location_encoder(coords)
+         location_features = self.location_projection(location_features)
+         return F.normalize(location_features, p=2, dim=-1)
+
+     def forward(
+         self,
+         image: torch.Tensor,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         coords: torch.Tensor = None,
+         return_loss: bool = False,
+     ) -> CLOSPOutput:
+         image_embeds = self.get_image_features(image)
+         text_embeds = self.get_text_features(input_ids, attention_mask)
+
+         # Cosine similarity as logits
+         logits_per_image = image_embeds @ text_embeds.T
+         logits_per_text = logits_per_image.T
+
+         # --- Optional Location Logic ---
+         location_embeds = None
+         logits_per_loc_img = None
+         logits_per_img_loc = None
+
+         if self.config.use_location_encoder:
+             if coords is None:
+                 raise ValueError(
+                     "Coordinates must be provided when use_location_encoder is True."
+                 )
+             location_embeds = self.get_location_features(coords)
+             logits_per_loc_img = location_embeds @ image_embeds.T
+             logits_per_img_loc = image_embeds @ location_embeds.T
+
+         # --- Optional Loss Calculation ---
+         loss = None
+         if return_loss:
+             outputs = [
+                 logits_per_image,
+                 logits_per_text,
+                 logits_per_loc_img,
+                 logits_per_img_loc,
+             ]
+             ground_truth = torch.arange(len(input_ids)).to(self.device)
+             loss = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
+             loss = sum(loss) / len(loss)
+
+         return CLOSPOutput(
+             loss=loss,
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             logits_per_loc_img=logits_per_loc_img,
+             logits_per_img_loc=logits_per_img_loc,
+             image_embeds=image_embeds,
+             text_embeds=text_embeds,
+             location_embeds=location_embeds,
+         )
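
A hedged inference sketch for the model above (not part of the commit). Two assumptions are worth flagging: the default vision_model_key "vit-s" must name an architecture registered with timm, so a standard timm key is substituted here purely for illustration; and the hardcoded LocationEncoder(512, 2, 256, 10) produces 256-dimensional features, so location_embedding_dim is set to 256 here to make the location projection's input size line up.

import torch

# Assumptions: "vit_small_patch16_224" stands in for the default "vit-s" key,
# and the distilbert-base-uncased tokenizer is downloadable.
config = CLOSPConfig(
    vision_model_key="vit_small_patch16_224",
    location_embedding_dim=256,  # matches the hardcoded LocationEncoder dim_out
)
model = CLOSPModel(config).eval()

tokens = model.tokenize_text(["flooded farmland along a river"])
s2_image = torch.randn(1, 13, 224, 224)  # Sentinel-2 input: 13 channels
coords = torch.tensor([[9.19, 45.46]])   # (lon, lat) in degrees

with torch.no_grad():
    out = model(
        image=s2_image,
        input_ids=tokens["input_ids"],
        attention_mask=tokens["attention_mask"],
        coords=coords,
        return_loss=True,
    )
print(out.logits_per_image, out.loss)

With return_loss=True, the forward pass averages the image-text, text-image, location-image, and image-location cross-entropy terms against in-batch identity targets, as in CLIP-style contrastive training.
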
positional_encoding.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright (c) Microsoft Corporation.
+
+ import math
+
+ import torch
+ from torch import nn
+
+ from .spherical_armonics import SH as SH_analytic
+
+
+ class SphericalHarmonics(nn.Module):
+     """
+     Spherical harmonics location encoder
+     """
+
+     def __init__(self, legendre_polys: int = 10, harmonics_calculation="analytic"):
+         """
+         legendre_polys: determines the number of Legendre polynomials;
+             more polynomials lead to more fine-grained resolutions
+         harmonics_calculation ("analytic" or "closed-form"):
+             "analytic" uses pre-computed equations. This is exact, but works only up to degree 50;
+             "closed-form" uses a single equation but is computationally slower (especially for high degrees)
+         """
+         super(SphericalHarmonics, self).__init__()
+         self.L, self.M = int(legendre_polys), int(legendre_polys)
+         self.embedding_dim = self.L * self.M
+
+         if harmonics_calculation == "closed-form":
+             self.SH = SH_closed_form
+         elif harmonics_calculation == "analytic":
+             self.SH = SH_analytic
+
+     def forward(self, lonlat):
+         lon, lat = lonlat[:, 0], lonlat[:, 1]
+
+         # convert degrees to radians
+         phi = torch.deg2rad(lon + 180)
+         theta = torch.deg2rad(lat + 90)
+         """
+         greater_than_50 = (lon > 50).any() or (lat > 50).any()
+         if greater_than_50:
+             SH = SH_closed_form
+         else:
+             SH = SH_analytic
+         """
+         SH = self.SH
+
+         Y = []
+         for l in range(self.L):
+             for m in range(-l, l + 1):
+                 y = SH(m, l, phi, theta)
+                 if isinstance(y, float):
+                     y = y * torch.ones_like(phi)
+                 if y.isnan().any():
+                     print(m, l, y)
+                 Y.append(y)
+
+         return torch.stack(Y, dim=-1)
+
+
+ ####################### Spherical Harmonics utilities ########################
+ # Code copied from https://github.com/BachiLi/redner/blob/master/pyredner/utils.py
+ # Code adapted from "Spherical Harmonic Lighting: The Gritty Details", Robin Green
+ # http://silviojemma.com/public/papers/lighting/spherical-harmonic-lighting.pdf
+ def associated_legendre_polynomial(l, m, x):
+     pmm = torch.ones_like(x)
+     if m > 0:
+         somx2 = torch.sqrt((1 - x) * (1 + x))
+         fact = 1.0
+         for i in range(1, m + 1):
+             pmm = pmm * (-fact) * somx2
+             fact += 2.0
+     if l == m:
+         return pmm
+     pmmp1 = x * (2.0 * m + 1.0) * pmm
+     if l == m + 1:
+         return pmmp1
+     pll = torch.zeros_like(x)
+     for ll in range(m + 2, l + 1):
+         pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m)
+         pmm = pmmp1
+         pmmp1 = pll
+     return pll
+
+
+ def SH_renormalization(l, m):
+     return math.sqrt(
+         (2.0 * l + 1.0) * math.factorial(l - m) / (4 * math.pi * math.factorial(l + m))
+     )
+
+
+ def SH_closed_form(m, l, phi, theta):
+     if m == 0:
+         return SH_renormalization(l, m) * associated_legendre_polynomial(
+             l, m, torch.cos(theta)
+         )
+     elif m > 0:
+         return (
+             math.sqrt(2.0)
+             * SH_renormalization(l, m)
+             * torch.cos(m * phi)
+             * associated_legendre_polynomial(l, m, torch.cos(theta))
+         )
+     else:
+         return (
+             math.sqrt(2.0)
+             * SH_renormalization(l, -m)
+             * torch.sin(-m * phi)
+             * associated_legendre_polynomial(l, -m, torch.cos(theta))
+         )
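
A small sanity-check sketch for the closed-form branch above (not part of the commit). It uses the fact that the first harmonic Y_0^0 is the constant 1/sqrt(4*pi) regardless of angle, and checks the embedding shape; it assumes SphericalHarmonics and SH_closed_form are in scope, e.g. imported from this module.

import math

import torch

phi = torch.tensor([0.5, 1.5])
theta = torch.tensor([1.0, 2.0])

# Y_0^0 is the constant 1/sqrt(4*pi), independent of the input angles.
y00 = SH_closed_form(0, 0, phi, theta)
assert torch.allclose(y00, torch.full_like(phi, 1 / math.sqrt(4 * math.pi)))

# The forward pass stacks one value per (l, m) pair, giving L**2 features.
sh = SphericalHarmonics(legendre_polys=10, harmonics_calculation="closed-form")
lonlat = torch.tensor([[9.19, 45.46]])  # (lon, lat) in degrees
print(sh(lonlat).shape)  # torch.Size([1, 100]) == legendre_polys ** 2
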
spherical_armonics.py ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1fc4e9b49abb4e81411376fc6d09b1281aa8ed96cef64b7aa95cc4aeeccb97a4
+ size 10994723