from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.utils import ModelOutput

from .location_encoder import LocationEncoder


class CLOSPConfig(PretrainedConfig):
    """
    Configuration class for CLOSPModel.

    Stores the hyperparameters used to instantiate a CLOSPModel: the timm key and
    embedding/head dimensions of the two vision encoders, the Hugging Face text
    backbone, the optional location encoder, and the dimension of the shared
    projection space.
    """

    model_type = "closp"

    def __init__(
        self,
        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,
        text_model_name_or_path: str = "distilbert-base-uncased",
        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,
        projection_dim: int = 768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim = s1_embedding_dim
        self.s2_embedding_dim = s2_embedding_dim
        self.s1_head_dim = s1_head_dim
        self.s2_head_dim = s2_head_dim
        self.text_model_name_or_path = text_model_name_or_path
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim
        self.projection_dim = projection_dim


@dataclass
class CLOSPOutput(ModelOutput):
    """
    Base class for CLOSP model outputs: the contrastive loss, the image/text and
    image/location similarity logits, and the normalized embeddings themselves.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    logits_per_loc_img: Optional[torch.FloatTensor] = None
    logits_per_img_loc: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    location_embeds: Optional[torch.FloatTensor] = None


class CLOSPModel(PreTrainedModel):
    """
    CLOSP model: a 2-channel (S1) and a 13-channel (S2) timm vision encoder, a
    Hugging Face text encoder, and an optional location encoder, whose outputs
    are compared in a shared embedding space for contrastive training.
    """

    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)

        # Vision encoders: one for 2-channel (S1) inputs, one for 13-channel (S2) inputs.
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

        # Text encoder (randomly initialized from its config) and its tokenizer.
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

        # Optional location branch.
        if config.use_location_encoder:
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """Tokenizes input text using the model's tokenizer."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Encodes an image tensor into L2-normalized features."""
        image = image.float()
        # Route by channel count: 2-channel inputs go to the S1 encoder,
        # everything else to the 13-channel S2 encoder.
        if image.shape[1] == 2:
            image_features = self.s1_projection(self.s1_encoder(image))
        else:
            image_features = self.s2_projection(self.s2_encoder(image))
        return F.normalize(image_features, p=2, dim=-1)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encodes text tokens into L2-normalized features."""
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Use the first-token embedding (the [CLS] position for BERT-style models)
        # as the sentence representation.
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """Encodes coordinates into L2-normalized features."""
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: Optional[torch.Tensor] = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        """Computes image/text (and optionally image/location) similarity logits."""
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

        # Embeddings are L2-normalized, so these dot products are cosine similarities.
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None

        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

        loss = None
        if return_loss:
            # Symmetric contrastive loss: matching pairs lie on the diagonal, so the
            # target for every available logit matrix is the sample's index in the batch.
            outputs = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            ground_truth = torch.arange(len(input_ids)).to(self.device)
            loss = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
            loss = sum(loss) / len(loss)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )
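

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The timm key, image size, captions, and the choice to disable the location
# branch below are assumptions for demonstration; the module itself does not
# prescribe them. Because this file uses a relative import, run code like this
# from within the package rather than executing this file directly.
#
#     config = CLOSPConfig(
#         vision_model_key="vit_small_patch16_224",  # assumed timm key with 384-dim features
#         use_location_encoder=False,                # skip the location branch in this sketch
#     )
#     model = CLOSPModel(config).eval()
#
#     images = torch.randn(2, 2, 224, 224)  # fake 2-channel (S1-style) batch
#     tokens = model.tokenize_text(["a flooded river plain", "a dense urban area"])
#     with torch.no_grad():
#         out = model(images, tokens["input_ids"], tokens["attention_mask"])
#     print(out.logits_per_image.shape)  # torch.Size([2, 2])
# ---------------------------------------------------------------------------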