Add NeoBERTForTokenClassification class
- config.json +23 -2
- model.py +70 -0
config.json CHANGED

@@ -6,7 +6,8 @@
     "AutoConfig": "model.NeoBERTConfig",
     "AutoModel": "model.NeoBERT",
     "AutoModelForMaskedLM": "model.NeoBERTLMHead",
-    "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
+    "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification",
+    "AutoModelForTokenClassification": "model.NeoBERTForTokenClassification"
   },
   "classifier_init_range": 0.02,
   "decoder_init_range": 0.02,
@@ -15,8 +16,28 @@
   "hidden_size": 768,
   "intermediate_size": 3072,
   "kwargs": {
+    "_commit_hash": null,
+    "architectures": [
+      "NeoBERTLMHead"
+    ],
+    "attn_implementation": null,
+    "auto_map": {
+      "AutoConfig": "model.NeoBERTConfig",
+      "AutoModel": "model.NeoBERT",
+      "AutoModelForMaskedLM": "model.NeoBERTLMHead",
+      "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
+    },
     "classifier_init_range": 0.02,
+    "dim_head": 64,
+    "kwargs": {
+      "classifier_init_range": 0.02,
+      "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
+      "trust_remote_code": true
+    },
+    "model_type": "neobert",
     "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
+    "torch_dtype": "float32",
+    "transformers_version": "4.48.2",
     "trust_remote_code": true
   },
   "max_length": 4096,
@@ -27,7 +48,7 @@
   "pad_token_id": 0,
   "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.51.3",
   "trust_remote_code": true,
   "vocab_size": 30522
 }
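With AutoModelForTokenClassification now registered in auto_map, the head can be loaded through the standard Auto classes. A minimal sketch, assuming the config ships on the Hub next to model.py; the repository id and label count below are placeholders, not taken from this commit:

from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "org/neobert-checkpoint"  # placeholder: use the repo that contains this config.json
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(
    repo_id,
    trust_remote_code=True,  # required so model.NeoBERTForTokenClassification is imported from the repo
    num_labels=9,            # placeholder, e.g. a CoNLL-2003 style NER tag set
)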
model.py CHANGED

@@ -27,6 +27,7 @@ from transformers.modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
     SequenceClassifierOutput,
+    TokenClassifierOutput
 )

 from .rotary import precompute_freqs_cis, apply_rotary_emb
@@ -432,3 +433,72 @@ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
             hidden_states=output.hidden_states if output_hidden_states else None,
             attentions=output.attentions if output_attentions else None,
         )
+
+class NeoBERTForTokenClassification(NeoBERTPreTrainedModel):
+    config_class = NeoBERTConfig
+
+    def __init__(self, config: NeoBERTConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.num_labels = getattr(config, "num_labels", 2)
+        self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
+        self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
+
+        self.model = NeoBERT(config)
+
+        self.dropout = nn.Dropout(self.classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+
+        self.post_init()
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        output = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            max_seqlen=max_seqlen,
+            cu_seqlens=cu_seqlens,
+            attention_mask=attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = output.last_hidden_state
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Reshape logits and labels to compute token classification loss
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=output.hidden_states if output_hidden_states else None,
+            attentions=output.attentions if output_attentions else None,
+        )
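For reference, a short usage sketch of the new head. The repository id and num_labels are placeholders, and the label handling follows the usual Transformers convention: since the loss is a plain CrossEntropyLoss, positions to exclude (padding, continuation sub-words) are labelled -100, its default ignore_index.

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "org/neobert-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(repo_id, trust_remote_code=True, num_labels=3)

enc = tokenizer("NeoBERT handles sequences up to 4096 tokens.", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

# One logit vector per sub-word token: shape [batch, seq_len, num_labels]
pred_ids = out.logits.argmax(dim=-1)

# During fine-tuning, pass labels of shape [batch, seq_len]; dummy tags here for illustration,
# with padding positions set to -100 so CrossEntropyLoss ignores them.
labels = torch.zeros_like(enc["input_ids"])
labels[enc["input_ids"] == tokenizer.pad_token_id] = -100
out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], labels=labels)
print(out.loss, out.logits.shape)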