from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.utils import ModelOutput

from .location_encoder import LocationEncoder


class CLOSPConfig(PretrainedConfig):
    """
    Configuration class for CLOSPModel.

    Stores the configuration of a CLOSPModel, which is used to instantiate the
    model according to the specified parameters.

    Args:
        vision_model_key: timm model key used for both the Sentinel-1 and Sentinel-2 encoders.
        s1_embedding_dim / s2_embedding_dim: feature dimension of the S1 / S2 vision backbone.
        s1_head_dim / s2_head_dim: `num_classes` passed to timm (0 keeps the backbone headless).
        text_model_name_or_path: Hugging Face name or path of the text encoder.
        use_location_encoder: whether to instantiate the coordinate encoder branch.
        location_embedding_dim: feature dimension of the location encoder.
        projection_dim: dimension of the shared embedding space (text features are used
            unprojected, so this should match the text encoder's hidden size).
    """

    model_type = "closp"

    def __init__(
        self,
        # Vision model parameters
        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,
        # Text model parameters
        text_model_name_or_path: str = "distilbert-base-uncased",
        # Location encoder parameters (optional)
        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,
        # General model parameters
        projection_dim: int = 768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim = s1_embedding_dim
        self.s2_embedding_dim = s2_embedding_dim
        self.text_model_name_or_path = text_model_name_or_path
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim
        self.projection_dim = projection_dim
        self.s1_head_dim = s1_head_dim
        self.s2_head_dim = s2_head_dim
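

# Example (illustrative sketch): constructing the configuration with its defaults and
# round-tripping it through `save_pretrained` / `from_pretrained`. The directory name
# below is hypothetical.
#
#     config = CLOSPConfig(
#         vision_model_key="vit-s",
#         text_model_name_or_path="distilbert-base-uncased",
#         use_location_encoder=True,
#         projection_dim=768,
#     )
#     config.save_pretrained("./closp-config")
#     config = CLOSPConfig.from_pretrained("./closp-config")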


# --- Structured Model Output ---
@dataclass
class CLOSPOutput(ModelOutput):
    """
    Base class for the CLOSP model's outputs.

    Contains the (optional) contrastive loss, the image/text and image/location
    similarity logits, and the L2-normalized embeddings of each modality.
    Location-related fields are `None` when the location encoder is disabled.
    """

    loss: torch.FloatTensor = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    logits_per_loc_img: torch.FloatTensor = None
    logits_per_img_loc: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    location_embeds: torch.FloatTensor = None
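

# Example (illustrative sketch): reading a CLOSPOutput returned by CLOSPModel.forward.
# `outputs` is a hypothetical result for a batch of images and captions.
#
#     probs = outputs.logits_per_image.softmax(dim=-1)   # per-image distribution over texts
#     print(outputs.loss, outputs.image_embeds.shape, outputs.text_embeds.shape)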


class CLOSPModel(PreTrainedModel):
    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)

        # --- Vision Encoders ---
        # Two separate timm backbones: Sentinel-1 (2 channels) and Sentinel-2 (13 channels).
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

        # --- Text Encoder ---
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

        # --- Location Encoder ---
        if config.use_location_encoder:
            # Hyperparameters are hard-coded to match the bundled LocationEncoder.
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """Tokenizes input text using the model's tokenizer."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
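
    # Example (illustrative sketch): tokenizing a caption and encoding it. The caption is
    # made up and `model` denotes an instantiated CLOSPModel.
    #
    #     tokens = model.tokenize_text("dense urban area along a river")
    #     text_embeds = model.get_text_features(tokens["input_ids"], tokens["attention_mask"])
    #     text_embeds.shape   # (1, hidden_size) -- 768 for distilbert-base-uncased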

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encodes an image tensor into L2-normalized features.

        The encoder is selected by channel count: 2-channel inputs are routed to the
        Sentinel-1 backbone, anything else to the Sentinel-2 backbone.
        """
        image = image.float()
        if image.shape[1] == 2:  # Sentinel-1
            image_features = self.s1_projection(self.s1_encoder(image))
        else:  # Sentinel-2
            image_features = self.s2_projection(self.s2_encoder(image))
        return F.normalize(image_features, p=2, dim=-1)
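
    # Example (illustrative sketch): the backbone is selected purely by channel count, so
    # Sentinel-1 and Sentinel-2 batches go through the same method. Shapes are illustrative
    # and depend on the chosen timm backbone.
    #
    #     s1_batch = torch.randn(4, 2, 224, 224)    # e.g. VV/VH Sentinel-1 patches
    #     s2_batch = torch.randn(4, 13, 224, 224)   # 13-band Sentinel-2 patches
    #     s1_embeds = model.get_image_features(s1_batch)   # (4, projection_dim)
    #     s2_embeds = model.get_image_features(s2_batch)   # (4, projection_dim)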

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encodes text tokens into L2-normalized features (first-token / [CLS] pooling)."""
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """Encodes coordinates into features."""
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. "
                "Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)
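
    # Example (illustrative sketch): encoding coordinates. The expected coordinate format is
    # defined by the bundled LocationEncoder; a batch of (lon, lat) pairs is assumed here.
    #
    #     coords = torch.tensor([[12.49, 41.89], [2.35, 48.85]])   # hypothetical coordinates
    #     loc_embeds = model.get_location_features(coords)         # (2, projection_dim)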

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: torch.Tensor = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        """Computes image/text (and optionally image/location) similarity logits."""
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

        # Cosine similarity as logits (embeddings are already L2-normalized)
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        # --- Optional Location Logic ---
        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None
        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

        # --- Optional Loss Calculation ---
        # Symmetric contrastive loss: every available logit matrix is scored against the
        # diagonal, since matching pairs share the same index within the batch.
        loss = None
        if return_loss:
            outputs = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            ground_truth = torch.arange(len(input_ids)).to(self.device)
            losses = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
            loss = sum(losses) / len(losses)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )
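

# Example (illustrative sketch): a full forward pass on random data. This assumes
# `vision_model_key` resolves to a registered timm model (e.g. "vit_small_patch16_224",
# whose 384-dim features match the default embedding dims) and that the text model and
# tokenizer can be downloaded. Entry i of each modality is treated as the matching pair.
#
#     config = CLOSPConfig(vision_model_key="vit_small_patch16_224")
#     model = CLOSPModel(config).eval()
#     images = torch.randn(2, 13, 224, 224)                    # Sentinel-2 batch
#     tokens = model.tokenize_text(["harbour at night", "snow-covered mountain ridge"])
#     coords = torch.tensor([[12.49, 41.89], [2.35, 48.85]])   # assumed (lon, lat) pairs
#     with torch.no_grad():
#         out = model(
#             image=images,
#             input_ids=tokens["input_ids"],
#             attention_mask=tokens["attention_mask"],
#             coords=coords,
#             return_loss=True,
#         )
#     print(out.loss, out.logits_per_image)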