# CLOSP-VL / modeling_closp.py
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
PretrainedConfig,
PreTrainedModel,
)
from transformers.utils import ModelOutput
from .location_encoder import LocationEncoder


class CLOSPConfig(PretrainedConfig):
"""
Configuration class for CLOSPModel.
This class stores the configuration of a CLOSPModel, which is used to instantiate the model
according to the specified parameters.
"""
model_type = "closp"
def __init__(
self,
# Vision model parameters
vision_model_key: str = "vit-s",
s1_embedding_dim: int = 384,
s2_embedding_dim: int = 384,
s1_head_dim: int = 0,
s2_head_dim: int = 0,
# Text model parameters
text_model_name_or_path: str = "distilbert-base-uncased",
# Location encoder parameters (optional)
use_location_encoder: bool = True,
location_embedding_dim: int = 512,
# General model parameters
projection_dim: int = 768,
**kwargs,
):
super().__init__(**kwargs)
self.vision_model_key = vision_model_key
self.s1_embedding_dim = s1_embedding_dim
self.s2_embedding_dim = s2_embedding_dim
self.text_model_name_or_path = text_model_name_or_path
self.use_location_encoder = use_location_encoder
self.location_embedding_dim = location_embedding_dim
self.projection_dim = projection_dim
self.s1_head_dim = s1_head_dim
self.s2_head_dim = s2_head_dim
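# Hedged usage sketch (not in the original file): instantiating the configuration with an
# explicit timm backbone key. "vit_small_patch16_224" is an assumed example; the 384-d
# embedding defaults above already match its feature size, and the config round-trips
# through the standard PretrainedConfig save/load machinery.
#
#     config = CLOSPConfig(
#         vision_model_key="vit_small_patch16_224",
#         text_model_name_or_path="distilbert-base-uncased",
#         use_location_encoder=True,
#         projection_dim=768,
#     )
#     config.save_pretrained("./closp-config")
#     config = CLOSPConfig.from_pretrained("./closp-config")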
# --- Structured Model Output ---
@dataclass
class CLOSPOutput(ModelOutput):
    """
    Output type of CLOSPModel.

    Logit fields are pairwise cosine-similarity matrices of shape (batch_size, batch_size);
    the *_embeds fields hold the corresponding L2-normalized per-sample representations.
    Location-related fields are None when the location encoder is disabled, and loss is
    None unless the model is called with return_loss=True.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    logits_per_loc_img: Optional[torch.FloatTensor] = None
    logits_per_img_loc: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    location_embeds: Optional[torch.FloatTensor] = None
class CLOSPModel(PreTrainedModel):
    """
    CLOSP dual-encoder model: two timm vision backbones (Sentinel-1 and Sentinel-2) and an
    optional location encoder are projected into a shared embedding space of size
    config.projection_dim, while text embeddings are taken directly from the text
    encoder's first-token hidden state (whose size must therefore equal projection_dim).
    """

    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)
        # --- Vision Encoders ---
        # Sentinel-1 encoder: 2 input channels (SAR polarizations).
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        # Sentinel-2 encoder: 13 input channels (multispectral bands).
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)
        # --- Text Encoder ---
        # NOTE: from_config builds the text encoder with randomly initialized weights;
        # pretrained weights are expected to be loaded via CLOSPModel.from_pretrained.
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
        # --- Location Encoder ---
        if config.use_location_encoder:
            # Location encoder hyperparameters are fixed here; the first argument (512)
            # matches the default config.location_embedding_dim.
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )
    def tokenize_text(self, text: Union[str, List[str]]):
        """Tokenizes input text with the model's tokenizer, padding and truncating to the
        tokenizer's model_max_length and returning PyTorch tensors."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encodes a batch of satellite images into L2-normalized embeddings in the shared
        projection space, routing by channel count (2 = Sentinel-1, otherwise Sentinel-2)."""
        image = image.float()
        if image.shape[1] == 2:  # Sentinel-1
            image_features = self.s1_projection(self.s1_encoder(image))
        else:  # Sentinel-2
            image_features = self.s2_projection(self.s2_encoder(image))
        return F.normalize(image_features, p=2, dim=-1)
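    # Hedged usage sketch (not in the original file): the sensor is inferred purely from
    # the channel count, so a (B, 2, H, W) batch is routed to the Sentinel-1 encoder and
    # a (B, 13, H, W) batch to the Sentinel-2 encoder; H and W must match the timm
    # backbone's expected input size (e.g. 224 for an assumed "vit_small_patch16_224").
    #
    #     s1_feats = model.get_image_features(torch.randn(4, 2, 224, 224))   # (4, projection_dim)
    #     s2_feats = model.get_image_features(torch.randn(4, 13, 224, 224))  # (4, projection_dim)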
    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encodes tokenized text into L2-normalized embeddings, using the hidden state of
        the first token ([CLS] for BERT-style encoders); no projection is applied, so the
        text encoder's hidden size must equal config.projection_dim."""
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Use the first token's representation as the sentence embedding.
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)
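    # Hedged usage sketch (not in the original file): tokenize_text feeds directly into
    # get_text_features. The resulting embeddings have the text encoder's hidden size
    # (768 for the default distilbert-base-uncased); no text projection layer is applied.
    #
    #     tokens = model.tokenize_text(["a flooded river next to farmland"])
    #     text_feats = model.get_text_features(tokens["input_ids"], tokens["attention_mask"])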
    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """Encodes geographic coordinates into L2-normalized embeddings in the shared
        projection space."""
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in the config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)
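    # Hedged usage sketch (not in the original file): coords are assumed to be one
    # (latitude, longitude) pair per sample; the exact input format is defined by the
    # bundled LocationEncoder.
    #
    #     loc_feats = model.get_location_features(torch.tensor([[46.5, 11.3]]))  # (1, projection_dim)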
    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: Optional[torch.Tensor] = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        """Computes image, text, and (optionally) location embeddings plus their pairwise
        similarity logits; if return_loss is True, also returns the averaged symmetric
        contrastive (cross-entropy) loss."""
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)
        # Cosine similarity between the L2-normalized embeddings, used directly as logits.
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T
# --- Optional Location Logic ---
location_embeds = None
logits_per_loc_img = None
logits_per_img_loc = None
if self.config.use_location_encoder:
if coords is None:
raise ValueError(
"Coordinates must be provided when use_location_encoder is True."
)
location_embeds = self.get_location_features(coords)
logits_per_loc_img = location_embeds @ image_embeds.T
logits_per_img_loc = image_embeds @ location_embeds.T
        # --- Optional Loss Calculation ---
        loss = None
        if return_loss:
            logit_matrices = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            # In-batch contrastive targets: the matching pair for sample i sits on the
            # diagonal, so index i is its target class.
            ground_truth = torch.arange(len(input_ids), device=self.device)
            losses = [
                F.cross_entropy(logits, ground_truth)
                for logits in logit_matrices
                if logits is not None
            ]
            loss = sum(losses) / len(losses)
return CLOSPOutput(
loss=loss,
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
logits_per_loc_img=logits_per_loc_img,
logits_per_img_loc=logits_per_img_loc,
image_embeds=image_embeds,
text_embeds=text_embeds,
location_embeds=location_embeds,
)
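

if __name__ == "__main__":
    # Hedged smoke-test sketch (not part of the original file). Assumptions:
    #   * "vit_small_patch16_224" is a registered model in the installed timm version
    #     (the default "vit-s" key is only a placeholder);
    #   * network access is available to fetch the DistilBERT config and tokenizer;
    #   * coords are one (latitude, longitude) pair per sample, which is what the
    #     bundled LocationEncoder is assumed to expect.
    # All weights are randomly initialized here, so the outputs are only useful for
    # checking shapes and that the forward pass runs.
    cfg = CLOSPConfig(vision_model_key="vit_small_patch16_224")
    model = CLOSPModel(cfg).eval()

    s2_images = torch.randn(2, 13, 224, 224)  # Sentinel-2 batch: 13 spectral bands
    tokens = model.tokenize_text(["a snow-covered mountain range", "dense urban area"])
    coords = torch.tensor([[46.5, 11.3], [48.1, 11.6]])  # assumed (lat, lon) pairs

    with torch.no_grad():
        out = model(
            image=s2_images,
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            coords=coords,
            return_loss=True,
        )

    print("logits_per_image:", tuple(out.logits_per_image.shape))  # (2, 2)
    print("contrastive loss:", out.loss.item())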