from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.utils import ModelOutput

from .location_encoder import LocationEncoder


class CLOSPConfig(PretrainedConfig):
    """
    Configuration class for CLOSPModel.

    This class stores the configuration of a CLOSPModel, which is used to
    instantiate the model according to the specified parameters.
    """

    model_type = "closp"

    def __init__(
        self,
        # Vision model parameters
        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,
        # Text model parameters
        text_model_name_or_path: str = "distilbert-base-uncased",
        # Location encoder parameters (optional)
        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,
        # General model parameters
        projection_dim: int = 768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim = s1_embedding_dim
        self.s2_embedding_dim = s2_embedding_dim
        self.text_model_name_or_path = text_model_name_or_path
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim
        self.projection_dim = projection_dim
        self.s1_head_dim = s1_head_dim
        self.s2_head_dim = s2_head_dim


# --- Structured Model Output ---
@dataclass
class CLOSPOutput(ModelOutput):
    """
    Base class for CLOSP model's outputs.
    """

    loss: torch.FloatTensor = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    logits_per_loc_img: torch.FloatTensor = None
    logits_per_img_loc: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    location_embeds: torch.FloatTensor = None


class CLOSPModel(PreTrainedModel):
    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)

        # --- Vision Encoders ---
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

        # --- Text Encoder ---
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

        # --- Location Encoder ---
        if config.use_location_encoder:
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """Tokenizes input text using the model's tokenizer."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encodes an image tensor into features."""
        image = image.float()
        if image.shape[1] == 2:  # Sentinel-1
            image_features = self.s1_projection(self.s1_encoder(image))
        else:  # Sentinel-2
            image_features = self.s2_projection(self.s2_encoder(image))
        return F.normalize(image_features, p=2, dim=-1)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encodes text tokens into features."""
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """Encodes coordinates into features."""
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. "
                "Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: torch.Tensor = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

        # Cosine similarity as logits
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        # --- Optional Location Logic ---
        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None
        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

        # --- Optional Loss Calculation ---
        loss = None
        if return_loss:
            outputs = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            ground_truth = torch.arange(len(input_ids)).to(self.device)
            loss = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
            loss = sum(loss) / len(loss)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )
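

# --- Usage sketch (illustrative only, not part of the CLOSP API) ---
# A minimal example of wiring the pieces together with random inputs. It assumes:
# `vision_model_key` names a timm architecture whose pooled feature size matches
# `s1_embedding_dim`/`s2_embedding_dim` (e.g. "vit_small_patch16_224" -> 384), the
# DistilBERT config/tokenizer can be downloaded, `coords` are (batch, 2) coordinate
# pairs, and the file is run as a module (python -m <package>.<module>) so the
# relative LocationEncoder import resolves. Shapes and values are placeholders.
if __name__ == "__main__":
    config = CLOSPConfig(
        vision_model_key="vit_small_patch16_224",  # assumed valid timm key for this timm version
        use_location_encoder=True,
    )
    model = CLOSPModel(config).eval()

    batch_size = 2
    s2_images = torch.randn(batch_size, 13, 224, 224)  # Sentinel-2: 13 spectral bands
    coords = torch.randn(batch_size, 2)  # placeholder coordinates
    tokens = model.tokenize_text(["a flooded river plain"] * batch_size)

    with torch.no_grad():
        out = model(
            image=s2_images,
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            coords=coords,
            return_loss=True,
        )

    # logits_per_image is (batch, batch); loss averages the image-text and
    # image-location contrastive terms computed in forward().
    print(out.logits_per_image.shape, out.loss)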