Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
@@ -1,39 +1,47 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
)
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
return {'logits': logits}
|
|
|
1 |
+
from transformers.utils import (
|
2 |
+
logging,
|
3 |
+
is_torch_available,
|
4 |
+
)
|
5 |
+
from transformers.modeling_utils import (
|
6 |
+
no_init_weights,
|
7 |
+
init_empty_weights,
|
8 |
+
)
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch
|
11 |
+
from transformers import PreTrainedModel, AutoModel
|
12 |
+
from transformers import PretrainedConfig
|
13 |
+
|
14 |
+
class DualDistilBERTClassifier(PreTrainedModel):
|
15 |
+
config_class = PretrainedConfig
|
16 |
+
def __init__(self, config):
|
17 |
+
super().__init__(config)
|
18 |
+
self.num_labels = config.num_labels
|
19 |
+
self.encoder = AutoModel.from_pretrained(config.base_model_name_or_path)
|
20 |
+
hidden_size = self.encoder.config.hidden_size
|
21 |
+
|
22 |
+
self.classifier = nn.Sequential(
|
23 |
+
nn.Linear(hidden_size * 2, 512),
|
24 |
+
nn.ReLU(),
|
25 |
+
nn.Dropout(0.1),
|
26 |
+
nn.Linear(512, self.num_labels)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, title_input_ids, title_attention_mask,
|
30 |
+
abstract_input_ids, abstract_attention_mask, labels=None):
|
31 |
+
|
32 |
+
title_out = self.encoder(input_ids=title_input_ids,
|
33 |
+
attention_mask=title_attention_mask).last_hidden_state[:, 0]
|
34 |
+
|
35 |
+
abstract_out = self.encoder(input_ids=abstract_input_ids,
|
36 |
+
attention_mask=abstract_attention_mask).last_hidden_state[:, 0]
|
37 |
+
|
38 |
+
# Если абстракт пустой — можно заменить его на нули
|
39 |
+
abstract_out = torch.where(
|
40 |
+
abstract_attention_mask.sum(dim=1).unsqueeze(1) > 0,
|
41 |
+
abstract_out,
|
42 |
+
torch.zeros_like(abstract_out)
|
43 |
+
)
|
44 |
+
|
45 |
+
combined = torch.cat([title_out, abstract_out], dim=1)
|
46 |
+
logits = self.classifier(combined)
|
47 |
return {'logits': logits}
|