import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, PretrainedConfig

class DualDistilBERTClassifier(PreTrainedModel):
    """Encodes the title and the abstract with a shared DistilBERT encoder
    and classifies from the concatenation of the two [CLS] embeddings."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # One shared encoder processes both the title and the abstract.
        self.encoder = AutoModel.from_pretrained(config.base_model_name_or_path)
        hidden_size = self.encoder.config.hidden_size

        # Classification head over the concatenated [CLS] embeddings,
        # hence the doubled input width.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, self.num_labels)
        )

    def forward(self, title_input_ids, title_attention_mask,
                abstract_input_ids, abstract_attention_mask, labels=None):

        # [CLS] embedding of the title.
        title_out = self.encoder(input_ids=title_input_ids,
                                 attention_mask=title_attention_mask).last_hidden_state[:, 0]

        # [CLS] embedding of the abstract.
        abstract_out = self.encoder(input_ids=abstract_input_ids,
                                    attention_mask=abstract_attention_mask).last_hidden_state[:, 0]

        # If the abstract is empty (all-zero attention mask), replace its
        # embedding with zeros; the (batch, 1) condition broadcasts over
        # the hidden dimension.
        abstract_out = torch.where(
            abstract_attention_mask.sum(dim=1).unsqueeze(1) > 0,
            abstract_out,
            torch.zeros_like(abstract_out)
        )

        combined = torch.cat([title_out, abstract_out], dim=1)
        logits = self.classifier(combined)

        # Return a loss alongside the logits when labels are given, following
        # the Hugging Face Trainer convention.
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return {'loss': loss, 'logits': logits}
        return {'logits': logits}
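

# --- Usage sketch ---
# A minimal example of instantiating and calling the model. The base encoder
# name ("distilbert-base-uncased"), num_labels=10, the sample texts, and the
# max lengths are illustrative assumptions, not values from the original file;
# substitute the ones used in your project.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = PretrainedConfig(
        num_labels=10,  # assumed number of classes
        base_model_name_or_path="distilbert-base-uncased",  # assumed encoder
    )
    model = DualDistilBERTClassifier(config)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

    # Title and abstract are tokenized separately, matching forward()'s inputs.
    title = tokenizer(["Attention Is All You Need"],
                      padding="max_length", truncation=True,
                      max_length=32, return_tensors="pt")
    abstract = tokenizer(["An example abstract about sequence models."],
                         padding="max_length", truncation=True,
                         max_length=256, return_tensors="pt")

    with torch.no_grad():
        out = model(title_input_ids=title["input_ids"],
                    title_attention_mask=title["attention_mask"],
                    abstract_input_ids=abstract["input_ids"],
                    abstract_attention_mask=abstract["attention_mask"])
    print(out["logits"].shape)  # torch.Size([1, 10])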