marinapLC committed on
Commit 140f5bf · 1 Parent(s): 5424c8e

Add NeoBERTForTokenClassification class

Files changed (2)
  1. config.json +23 -2
  2. model.py +70 -0
config.json CHANGED
@@ -6,7 +6,8 @@
     "AutoConfig": "model.NeoBERTConfig",
     "AutoModel": "model.NeoBERT",
     "AutoModelForMaskedLM": "model.NeoBERTLMHead",
-    "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
+    "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification",
+    "AutoModelForTokenClassification": "model.NeoBERTForTokenClassification"
   },
   "classifier_init_range": 0.02,
   "decoder_init_range": 0.02,
@@ -15,8 +16,28 @@
   "hidden_size": 768,
   "intermediate_size": 3072,
   "kwargs": {
+    "_commit_hash": null,
+    "architectures": [
+      "NeoBERTLMHead"
+    ],
+    "attn_implementation": null,
+    "auto_map": {
+      "AutoConfig": "model.NeoBERTConfig",
+      "AutoModel": "model.NeoBERT",
+      "AutoModelForMaskedLM": "model.NeoBERTLMHead",
+      "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
+    },
     "classifier_init_range": 0.02,
+    "dim_head": 64,
+    "kwargs": {
+      "classifier_init_range": 0.02,
+      "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
+      "trust_remote_code": true
+    },
+    "model_type": "neobert",
     "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
+    "torch_dtype": "float32",
+    "transformers_version": "4.48.2",
     "trust_remote_code": true
   },
   "max_length": 4096,
@@ -27,7 +48,7 @@
   "pad_token_id": 0,
   "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
   "torch_dtype": "float32",
-  "transformers_version": "4.48.2",
+  "transformers_version": "4.51.3",
   "trust_remote_code": true,
   "vocab_size": 30522
 }
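With the new auto_map entry, the Auto classes can resolve the token-classification head directly from the Hub repository. A minimal loading sketch: the repository id below is a placeholder (not taken from this commit), num_labels=9 is just an example value (the class falls back to 2 when the config does not define num_labels), and trust_remote_code=True is required because the classes live in this repo's model.py.

from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "your-org/neobert-checkpoint"  # placeholder: the Hub repo hosting this config.json and model.py

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
# auto_map["AutoModelForTokenClassification"] points at model.NeoBERTForTokenClassification
model = AutoModelForTokenClassification.from_pretrained(
    repo_id,
    trust_remote_code=True,
    num_labels=9,  # example: size of a CoNLL-2003-style NER tag set
)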
model.py CHANGED
@@ -27,6 +27,7 @@ from transformers.modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
     SequenceClassifierOutput,
+    TokenClassifierOutput
 )
 
 from .rotary import precompute_freqs_cis, apply_rotary_emb
@@ -432,3 +433,72 @@ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
             hidden_states=output.hidden_states if output_hidden_states else None,
             attentions=output.attentions if output_attentions else None,
         )
+
+class NeoBERTForTokenClassification(NeoBERTPreTrainedModel):
+    config_class = NeoBERTConfig
+
+    def __init__(self, config: NeoBERTConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.num_labels = getattr(config, "num_labels", 2)
+        self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
+        self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
+
+        self.model = NeoBERT(config)
+
+        self.dropout = nn.Dropout(self.classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+
+        self.post_init()
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        output = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            max_seqlen=max_seqlen,
+            cu_seqlens=cu_seqlens,
+            attention_mask=attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = output.last_hidden_state
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Reshape logits and labels to compute token classification loss
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=output.hidden_states if output_hidden_states else None,
+            attentions=output.attentions if output_attentions else None,
+        )
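The new head mirrors NeoBERTForSequenceClassification but keeps one logit vector per token, so logits have shape (batch, seq_len, num_labels) and are flattened together with the labels for CrossEntropyLoss. A rough usage sketch, reusing the placeholder model and tokenizer from the loading example above; the label values are made up, and positions labelled -100 are ignored by CrossEntropyLoss's default ignore_index:

import torch

# One label id per token position; dummy values for illustration only.
enc = tokenizer("NeoBERT adds a token classification head", return_tensors="pt")
labels = torch.zeros_like(enc["input_ids"])  # shape (1, seq_len), all label id 0
labels[0, 0] = -100                          # e.g. exclude the [CLS] position from the loss

out = model(input_ids=enc["input_ids"], labels=labels)
print(out.logits.shape)  # (1, seq_len, num_labels)
print(out.loss)          # scalar cross-entropy over the unmasked token positions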