import torch
import torch.nn as nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel


class DualDistilBERTClassifier(PreTrainedModel):
    """Encodes title and abstract separately with a shared encoder, then classifies."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # One encoder is shared between the title and abstract branches.
        self.encoder = AutoModel.from_pretrained(config.base_model_name_or_path)
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, self.num_labels),
        )

    def forward(self, title_input_ids, title_attention_mask,
                abstract_input_ids, abstract_attention_mask, labels=None):
        # Use the [CLS] token embedding of each text as its representation.
        title_out = self.encoder(
            input_ids=title_input_ids,
            attention_mask=title_attention_mask,
        ).last_hidden_state[:, 0]
        abstract_out = self.encoder(
            input_ids=abstract_input_ids,
            attention_mask=abstract_attention_mask,
        ).last_hidden_state[:, 0]

        # If the abstract is empty (all-zero attention mask), replace its
        # embedding with zeros so it does not contribute noise.
        abstract_out = torch.where(
            abstract_attention_mask.sum(dim=1, keepdim=True) > 0,
            abstract_out,
            torch.zeros_like(abstract_out),
        )

        combined = torch.cat([title_out, abstract_out], dim=1)
        logits = self.classifier(combined)

        # The signature accepts labels, so compute the loss when they are
        # provided (this is the contract transformers' Trainer expects).
        output = {'logits': logits}
        if labels is not None:
            output['loss'] = nn.functional.cross_entropy(logits, labels)
        return output
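
# --- Usage sketch. Assumptions not fixed by the code above: the base model
# ('distilbert-base-uncased'), the number of labels (5), and the example
# texts are all illustrative placeholders. ---
from transformers import AutoTokenizer

config = PretrainedConfig(
    num_labels=5,
    base_model_name_or_path='distilbert-base-uncased',
)
model = DualDistilBERTClassifier(config)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

titles = ['Attention Is All You Need']
abstracts = ['We propose the Transformer, a model based solely on attention.']

# Tokenize title and abstract independently; each gets its own mask,
# which the forward pass uses to detect empty abstracts.
title_enc = tokenizer(titles, padding=True, truncation=True,
                      max_length=64, return_tensors='pt')
abstract_enc = tokenizer(abstracts, padding=True, truncation=True,
                         max_length=256, return_tensors='pt')

with torch.no_grad():
    out = model(
        title_input_ids=title_enc['input_ids'],
        title_attention_mask=title_enc['attention_mask'],
        abstract_input_ids=abstract_enc['input_ids'],
        abstract_attention_mask=abstract_enc['attention_mask'],
    )
print(out['logits'].shape)  # torch.Size([1, 5])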