import torch
import torch.nn as nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel


class DualDistilBERTClassifier(PreTrainedModel):
    """Encodes the title and the abstract with a shared encoder and
    classifies the concatenation of their [CLS] representations."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # One shared encoder processes both the title and the abstract.
        self.encoder = AutoModel.from_pretrained(config.base_model_name_or_path)
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, self.num_labels),
        )

    def forward(self, title_input_ids, title_attention_mask,
                abstract_input_ids, abstract_attention_mask, labels=None):
        # Take the [CLS] token embedding (position 0) from each encoding.
        title_out = self.encoder(
            input_ids=title_input_ids,
            attention_mask=title_attention_mask).last_hidden_state[:, 0]
        abstract_out = self.encoder(
            input_ids=abstract_input_ids,
            attention_mask=abstract_attention_mask).last_hidden_state[:, 0]
        # If the abstract is empty (all-zero attention mask), replace its
        # embedding with zeros so it contributes nothing downstream.
        abstract_out = torch.where(
            abstract_attention_mask.sum(dim=1, keepdim=True) > 0,
            abstract_out,
            torch.zeros_like(abstract_out),
        )
        combined = torch.cat([title_out, abstract_out], dim=1)
        logits = self.classifier(combined)
        output = {'logits': logits}
        # Compute the loss when labels are provided so the model can be
        # trained with the Hugging Face Trainer out of the box.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            output['loss'] = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1))
        return output
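
For reference, a minimal usage sketch. The checkpoint name "distilbert-base-uncased", num_labels=5, and the sample texts are assumptions for illustration; the original snippet does not specify them.

import torch
from transformers import AutoTokenizer, PretrainedConfig

# Hypothetical setup: checkpoint and label count are assumptions,
# not values taken from the original snippet.
config = PretrainedConfig(num_labels=5,
                          base_model_name_or_path="distilbert-base-uncased")
model = DualDistilBERTClassifier(config)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

title_enc = tokenizer(["Attention Is All You Need"],
                      padding=True, truncation=True, return_tensors="pt")
abstract_enc = tokenizer(["The dominant sequence transduction models..."],
                         padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    out = model(title_input_ids=title_enc["input_ids"],
                title_attention_mask=title_enc["attention_mask"],
                abstract_input_ids=abstract_enc["input_ids"],
                abstract_attention_mask=abstract_enc["attention_mask"])
print(out["logits"].shape)  # torch.Size([1, 5]), i.e. (batch_size, num_labels)

One caveat on the empty-abstract branch: with default settings the tokenizer still emits [CLS] and [SEP] for an empty string, so the attention mask is not all zeros and the torch.where zeroing never fires. For it to take effect, the data pipeline has to supply an all-zero attention mask for missing abstracts.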