DeniSSio commited on
Commit
d08d331
·
verified ·
1 Parent(s): 9f12fee

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +46 -38
model.py CHANGED
@@ -1,39 +1,47 @@
1
- import torch.nn as nn
2
- import torch
3
- from transformers import PreTrainedModel, AutoModel
4
- from transformers import PretrainedConfig
5
-
6
- class DualDistilBERTClassifier(PreTrainedModel):
7
- config_class = PretrainedConfig
8
- def __init__(self, config):
9
- super().__init__(config)
10
- self.num_labels = config.num_labels
11
- self.encoder = AutoModel.from_pretrained(config.base_model_name_or_path)
12
- hidden_size = self.encoder.config.hidden_size
13
-
14
- self.classifier = nn.Sequential(
15
- nn.Linear(hidden_size * 2, 512),
16
- nn.ReLU(),
17
- nn.Dropout(0.1),
18
- nn.Linear(512, self.num_labels)
19
- )
20
-
21
- def forward(self, title_input_ids, title_attention_mask,
22
- abstract_input_ids, abstract_attention_mask, labels=None):
23
-
24
- title_out = self.encoder(input_ids=title_input_ids,
25
- attention_mask=title_attention_mask).last_hidden_state[:, 0]
26
-
27
- abstract_out = self.encoder(input_ids=abstract_input_ids,
28
- attention_mask=abstract_attention_mask).last_hidden_state[:, 0]
29
-
30
- # Если абстракт пустой — можно заменить его на нули
31
- abstract_out = torch.where(
32
- abstract_attention_mask.sum(dim=1).unsqueeze(1) > 0,
33
- abstract_out,
34
- torch.zeros_like(abstract_out)
35
- )
36
-
37
- combined = torch.cat([title_out, abstract_out], dim=1)
38
- logits = self.classifier(combined)
 
 
 
 
 
 
 
 
39
  return {'logits': logits}
 
1
+ from transformers.utils import (
2
+ logging,
3
+ is_torch_available,
4
+ )
5
+ from transformers.modeling_utils import (
6
+ no_init_weights,
7
+ init_empty_weights,
8
+ )
9
+ import torch.nn as nn
10
+ import torch
11
+ from transformers import PreTrainedModel, AutoModel
12
+ from transformers import PretrainedConfig
13
+
14
+ class DualDistilBERTClassifier(PreTrainedModel):
15
+ config_class = PretrainedConfig
16
+ def __init__(self, config):
17
+ super().__init__(config)
18
+ self.num_labels = config.num_labels
19
+ self.encoder = AutoModel.from_pretrained(config.base_model_name_or_path)
20
+ hidden_size = self.encoder.config.hidden_size
21
+
22
+ self.classifier = nn.Sequential(
23
+ nn.Linear(hidden_size * 2, 512),
24
+ nn.ReLU(),
25
+ nn.Dropout(0.1),
26
+ nn.Linear(512, self.num_labels)
27
+ )
28
+
29
+ def forward(self, title_input_ids, title_attention_mask,
30
+ abstract_input_ids, abstract_attention_mask, labels=None):
31
+
32
+ title_out = self.encoder(input_ids=title_input_ids,
33
+ attention_mask=title_attention_mask).last_hidden_state[:, 0]
34
+
35
+ abstract_out = self.encoder(input_ids=abstract_input_ids,
36
+ attention_mask=abstract_attention_mask).last_hidden_state[:, 0]
37
+
38
+ # Если абстракт пустой — можно заменить его на нули
39
+ abstract_out = torch.where(
40
+ abstract_attention_mask.sum(dim=1).unsqueeze(1) > 0,
41
+ abstract_out,
42
+ torch.zeros_like(abstract_out)
43
+ )
44
+
45
+ combined = torch.cat([title_out, abstract_out], dim=1)
46
+ logits = self.classifier(combined)
47
  return {'logits': logits}