ccdv committed on
Commit
68a7cc0
·
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,145 @@
+ ---
+ language: fr
+ tags:
+ - long context
+ pipeline_tag: fill-mask
+ ---
+
+ # LSG model
+ **Transformers >= 4.18.0**\
+ **This model relies on a custom modeling file; you need to add trust_remote_code=True**\
+ **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
+
+ * [Usage](#usage)
+ * [Parameters](#parameters)
+ * [Sparse selection type](#sparse-selection-type)
+ * [Tasks](#tasks)
+ * [Training global tokens](#training-global-tokens)
+
+ This model is a small version of the [distilcamembert-base](https://huggingface.co/cmarkea/distilcamembert-base) model, without additional pretraining yet. It uses the same number of parameters/layers and the same tokenizer.
+
+
+ This model can handle long sequences, faster and more efficiently than Longformer or BigBird (from Transformers), and relies on Local + Sparse + Global attention (LSG).
+
+
+ The model requires sequences whose length is a multiple of the block size. The model is "adaptive" and automatically pads the sequences if needed (adaptive=True in the config). It is nevertheless recommended to let the tokenizer truncate the inputs (truncation=True) and optionally pad them to a multiple of the block size (pad_to_multiple_of=...), as sketched below.
+
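+ A minimal tokenization sketch (the 128 below is the default block_size from config.json; adjust it if you change the block size):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilcamembert-base-4096")
+
+ # Truncate to the model maximum length and pad up to a multiple of the block size,
+ # so the adaptive padding inside the model has nothing left to do.
+ inputs = tokenizer(
+     "Un long document en français ...",
+     truncation=True,
+     padding=True,
+     pad_to_multiple_of=128,
+     return_tensors="pt",
+ )
+ ```
+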
+
+ The encoder-decoder setting is supported, but it has not been tested extensively.\
+ Implemented in PyTorch.
+
+ ![attn](attn.png)
+
+ ## Usage
+ The model relies on a custom modeling file; you need to add trust_remote_code=True to use it.
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+
+ model = AutoModel.from_pretrained("ccdv/lsg-distilcamembert-base-4096", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilcamembert-base-4096")
+ ```
+
+ ## Parameters
+ You can change various parameters like:
+ * the number of global tokens (num_global_tokens=1)
+ * local block size (block_size=128)
+ * sparse block size (sparse_block_size=128)
+ * sparsity factor (sparsity_factor=2)
+ * see the config.json file
+
+ Default parameters work well in practice. If you are short on memory, reduce the block sizes, increase the sparsity factor and remove dropout in the attention score matrix. Note that block_size must stay divisible by sparsity_factor (this is checked in the custom config).
+
+ ```python
+ model = AutoModel.from_pretrained("ccdv/lsg-distilcamembert-base-4096",
+     trust_remote_code=True,
+     num_global_tokens=16,
+     block_size=64,
+     sparse_block_size=64,
+     sparsity_factor=4,
+     attention_probs_dropout_prob=0.0
+ )
+ ```
+
+ ## Sparse selection type
+
+ There are 5 different sparse selection patterns; the best type is task-dependent (see the loading sketch after this list). \
+ Note that for sequences with length < 2*block_size, the type has no effect.
+
+ * sparsity_type="norm", select the highest-norm tokens
+   * Works best for a small sparsity_factor (2 to 4)
+   * Additional parameters:
+     * None
+ * sparsity_type="pooling", use average pooling to merge tokens
+   * Works best for a small sparsity_factor (2 to 4)
+   * Additional parameters:
+     * None
+ * sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
+   * Works best for a large sparsity_factor (4+)
+   * LSH relies on random projections, thus inference may differ slightly with different seeds
+   * Additional parameters:
+     * lsh_num_pre_rounds=1, pre-merge tokens n times before computing centroids
+ * sparsity_type="stride", use a striding mechanism per head
+   * Each head will use different tokens strided by sparsity_factor
+   * Not recommended if sparsity_factor > num_heads
+ * sparsity_type="block_stride", use a striding mechanism per head
+   * Each head will use blocks of tokens strided by sparsity_factor
+   * Not recommended if sparsity_factor > num_heads
+
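+ For example, to switch the selection pattern when loading the model (a sketch; "lsh" and the values below are only illustrative, the parameter names match config.json):
+
+ ```python
+ from transformers import AutoModel
+
+ # "lsh" pairs well with a larger sparsity_factor; lsh_num_pre_rounds controls the pre-merging rounds.
+ model = AutoModel.from_pretrained("ccdv/lsg-distilcamembert-base-4096",
+     trust_remote_code=True,
+     sparsity_type="lsh",
+     sparsity_factor=4,
+     lsh_num_pre_rounds=1,
+ )
+ ```
+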
+ ## Tasks
+ Fill mask example:
+ ```python
+ from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
+
+ model = AutoModelForMaskedLM.from_pretrained("ccdv/lsg-distilcamembert-base-4096", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilcamembert-base-4096")
+
+ SENTENCES = ["Paris is the <mask> of France.", "The goal of life is <mask>."]
+ pipeline = FillMaskPipeline(model, tokenizer)
+ output = pipeline(SENTENCES, top_k=1)
+
+ output = [o[0]["sequence"] for o in output]
+ > ['Paris is the capital of France.', 'The goal of life is happiness.']
+ ```
+
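+ The prompts above are in English; since the base model is a French CamemBERT, French prompts are closer to its training distribution. A hypothetical variant (output not verified here):
+
+ ```python
+ output = pipeline(["Paris est la <mask> de la France."], top_k=1)
+ print([o[0]["sequence"] for o in output])
+ ```
+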
+
+ Classification example:
+ ```python
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-distilcamembert-base-4096",
+     trust_remote_code=True,
+     pool_with_global=True, # pool with a global token instead of first token
+ )
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilcamembert-base-4096")
+
+ SENTENCE = "This is a test for sequence classification. " * 300
+ token_ids = tokenizer(
+     SENTENCE,
+     return_tensors="pt",
+     #pad_to_multiple_of=... # Optional
+     truncation=True
+ )
+ output = model(**token_ids)
+
+ > SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
+ ```
+
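+ To turn the logits from the example above into a class prediction (generic PyTorch post-processing, not specific to LSG):
+
+ ```python
+ import torch
+
+ probs = torch.softmax(output.logits, dim=-1)
+ predicted_class = probs.argmax(dim=-1).item()
+ ```
+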
+ ## Training global tokens
+ To train global tokens and the classification head only:
+ ```python
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-distilcamembert-base-4096",
+     trust_remote_code=True,
+     pool_with_global=True, # pool with a global token instead of first token
+     num_global_tokens=16
+ )
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-distilcamembert-base-4096")
+
+ for name, param in model.named_parameters():
+     # freeze everything except the global token embeddings and the classification head
+     if "global_embeddings" not in name and "classifier" not in name:
+         param.requires_grad = False
+     else:
+         param.requires_grad = True
+ ```
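+
+ A quick way to check what remains trainable (a generic PyTorch sketch, not from the original card):
+
+ ```python
+ trainable = [n for n, p in model.named_parameters() if p.requires_grad]
+ print(trainable)  # expect only global_embeddings and classifier.* parameters
+ print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")
+ ```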
attn.png ADDED
config.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "_name_or_path": "ccdv/lsg-distilcamembert-base-4096",
3
+ "adaptive": true,
4
+ "architectures": [
5
+ "LSGCamembertForMaskedLM"
6
+ ],
7
+ "attention_probs_dropout_prob": 0.1,
8
+ "auto_map": {
9
+ "AutoConfig": "modeling_lsg_camembert.LSGCamembertConfig",
10
+ "AutoModel": "modeling_lsg_camembert.LSGCamembertModel",
11
+ "AutoModelForCausalLM": "modeling_lsg_camembert.LSGCamembertForCausalLM",
12
+ "AutoModelForMaskedLM": "modeling_lsg_camembert.LSGCamembertForMaskedLM",
13
+ "AutoModelForMultipleChoice": "modeling_lsg_camembert.LSGCamembertForMultipleChoice",
14
+ "AutoModelForQuestionAnswering": "modeling_lsg_camembert.LSGCamembertForQuestionAnswering",
15
+ "AutoModelForSequenceClassification": "modeling_lsg_camembert.LSGCamembertForSequenceClassification",
16
+ "AutoModelForTokenClassification": "modeling_lsg_camembert.LSGCamembertForTokenClassification"
17
+ },
18
+ "base_model_prefix": "lsg",
19
+ "block_size": 128,
20
+ "bos_token_id": 0,
21
+ "classifier_dropout": null,
22
+ "eos_token_id": 2,
23
+ "gradient_checkpointing": false,
24
+ "hidden_act": "gelu",
25
+ "hidden_dropout_prob": 0.1,
26
+ "hidden_size": 768,
27
+ "initializer_range": 0.02,
28
+ "intermediate_size": 3072,
29
+ "layer_norm_eps": 1e-05,
30
+ "lsh_num_pre_rounds": 1,
31
+ "max_position_embeddings": 4098,
32
+ "model_type": "camembert",
33
+ "num_attention_heads": 12,
34
+ "num_global_tokens": 1,
35
+ "num_hidden_layers": 6,
36
+ "pad_token_id": 1,
37
+ "pool_with_global": true,
38
+ "position_embedding_type": "absolute",
39
+ "sparse_block_size": 128,
40
+ "sparsity_factor": 2,
41
+ "sparsity_type": "norm",
42
+ "torch_dtype": "float32",
43
+ "transformers_version": "4.19.2",
44
+ "type_vocab_size": 1,
45
+ "use_cache": true,
46
+ "vocab_size": 32005
47
+ }
modeling_lsg_camembert.py ADDED
@@ -0,0 +1,1239 @@
1
+ from logging import warn
2
+ from transformers.models.roberta.modeling_roberta import *
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.models.camembert.configuration_camembert import CamembertConfig
6
+ import sys
7
+
8
+ AUTO_MAP = {
9
+ "AutoModel": "modeling_lsg_camembert.LSGCamembertModel",
10
+ "AutoModelForCausalLM": "modeling_lsg_camembert.LSGCamembertForCausalLM",
11
+ "AutoModelForMaskedLM": "modeling_lsg_camembert.LSGCamembertForMaskedLM",
12
+ "AutoModelForMultipleChoice": "modeling_lsg_camembert.LSGCamembertForMultipleChoice",
13
+ "AutoModelForQuestionAnswering": "modeling_lsg_camembert.LSGCamembertForQuestionAnswering",
14
+ "AutoModelForSequenceClassification": "modeling_lsg_camembert.LSGCamembertForSequenceClassification",
15
+ "AutoModelForTokenClassification": "modeling_lsg_camembert.LSGCamembertForTokenClassification"
16
+ }
17
+
18
+ class LSGCamembertConfig(CamembertConfig):
19
+ """
20
+ This class overrides :class:`~transformers.CamembertConfig`. Please check the superclass for the appropriate
21
+ documentation alongside usage examples.
22
+ """
23
+
24
+ base_model_prefix = "lsg"
25
+ model_type = "camembert"
26
+
27
+ def __init__(
28
+ self,
29
+ adaptive=True,
30
+ base_model_prefix="lsg",
31
+ block_size=128,
32
+ lsh_num_pre_rounds=1,
33
+ num_global_tokens=1,
34
+ pool_with_global=True,
35
+ sparse_block_size=128,
36
+ sparsity_factor=2,
37
+ sparsity_type="norm",
38
+ **kwargs
39
+ ):
40
+ """Constructs LSGCamembertConfig."""
41
+ super().__init__(**kwargs)
42
+
43
+ self.adaptive = adaptive
44
+ self.auto_map = AUTO_MAP
45
+ self.base_model_prefix = base_model_prefix
46
+ self.block_size = block_size
47
+ self.lsh_num_pre_rounds = lsh_num_pre_rounds
48
+ self.num_global_tokens = num_global_tokens
49
+ self.pool_with_global = pool_with_global
50
+ self.sparse_block_size = sparse_block_size
51
+ self.sparsity_factor = sparsity_factor
52
+ self.sparsity_type = sparsity_type
53
+
54
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
55
+ logger.warning(
56
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
57
+ self.sparsity_type = None
58
+
59
+ if self.sparsity_type in ["stride", "block_stride"]:
60
+ if self.sparsity_factor > self.num_attention_heads:
61
+ logger.warning(
62
+ "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
63
+ )
64
+
65
+ if self.num_global_tokens < 1:
66
+ logger.warning(
67
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
68
+ )
69
+ self.num_global_tokens = 1
70
+ elif self.num_global_tokens > 512:
71
+ logger.warning(
72
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
73
+ )
74
+ self.num_global_tokens = 512
75
+
76
+ if self.sparsity_factor > 0:
77
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
78
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
79
+
80
+
81
+ class BaseSelfAttention(nn.Module):
82
+
83
+ def init_modules(self, config):
84
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
85
+ config, "embedding_size"
86
+ ):
87
+ raise ValueError(
88
+ "The hidden size (%d) is not a multiple of the number of attention "
89
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
90
+ )
91
+
92
+ self.num_attention_heads = config.num_attention_heads
93
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
94
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
95
+
96
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
97
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
98
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
99
+
100
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
101
+
102
+ def transpose_for_scores(self, x):
103
+ new_x_shape = x.size()[:-1] + (
104
+ self.num_attention_heads,
105
+ self.attention_head_size,
106
+ )
107
+ x = x.view(*new_x_shape)
108
+ return x.permute(0, 2, 1, 3)
109
+
110
+ def reshape_output(self, context_layer):
111
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
112
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
113
+ return context_layer.view(*new_context_layer_shape)
114
+
115
+ def project_QKV(self, hidden_states):
116
+
117
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
118
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
119
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
120
+ return query_layer, key_layer, value_layer
121
+
122
+
123
+ class BaseAttentionProduct(nn.Module):
124
+
125
+ def __init__(self, config):
126
+ """
127
+ Compute attention: softmax(Q @ K.T) @ V
128
+ """
129
+ super().__init__()
130
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
131
+
132
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
133
+
134
+ d = query_layer.shape[-1]
135
+
136
+ # Take the dot product between "query" and "key" to get the raw attention scores.
137
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
138
+
139
+ del query_layer
140
+ del key_layer
141
+
142
+ if attention_mask is not None:
143
+ # Apply the attention mask (precomputed for all layers in CamembertModel forward() function)
144
+ attention_scores = attention_scores + attention_mask
145
+ del attention_mask
146
+
147
+ # Normalize the attention scores to probabilities.
148
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
149
+
150
+ # This is actually dropping out entire tokens to attend to, which might
151
+ # seem a bit unusual, but is taken from the original Transformer paper.
152
+ context_layer = self.dropout(attention_probs) @ value_layer
153
+
154
+ return context_layer
155
+
156
+
157
+ class CausalAttentionProduct(nn.Module):
158
+
159
+ def __init__(self, config):
160
+ """
161
+ Compute attention: softmax(Q @ K.T) @ V
162
+ """
163
+ super().__init__()
164
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
165
+ self.block_size = config.block_size
166
+
167
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
168
+
169
+ d = query_layer.shape[-1]
170
+
171
+ # Take the dot product between "query" and "key" to get the raw attention scores.
172
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
173
+
174
+ del query_layer
175
+ del key_layer
176
+
177
+ if attention_mask is not None:
178
+ # Apply the attention mask (precomputed for all layers in CamembertModel forward() function)
179
+ attention_scores = attention_scores + attention_mask
180
+
181
+ # Add causal mask
182
+ causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
183
+ causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * (-10000)
184
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
185
+
186
+ del attention_mask
187
+
188
+ # Normalize the attention scores to probabilities.
189
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
190
+
191
+ # This is actually dropping out entire tokens to attend to, which might
192
+ # seem a bit unusual, but is taken from the original Transformer paper.
193
+ context_layer = self.dropout(attention_probs) @ value_layer
194
+
195
+ return context_layer
196
+
197
+
198
+ class LSGAttentionProduct(nn.Module):
199
+
200
+ def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
201
+ """
202
+ Compute block or overlapping blocks attention products
203
+ """
204
+ super().__init__()
205
+
206
+ self.block_size = block_size
207
+ self.sparse_block_size = sparse_block_size
208
+ self.sparsity_factor = sparsity_factor
209
+ self.is_causal = is_causal
210
+
211
+ if self.block_size is None:
212
+ self.block_size = config.block_size
213
+
214
+ if self.sparse_block_size is None:
215
+ self.sparse_block_size = config.sparse_block_size
216
+
217
+ # Shape of blocks
218
+ self.local_shapes = (self.block_size*3, self.block_size)
219
+ if self.sparse_block_size and self.sparsity_factor > 0:
220
+ self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
221
+
222
+ if is_causal:
223
+ self.attention = CausalAttentionProduct(config)
224
+ else:
225
+ self.attention = BaseAttentionProduct(config)
226
+
227
+ def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
228
+
229
+ # Build local tokens
230
+ local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
231
+ del hidden_states
232
+
233
+ # Build sparse tokens
234
+ if sparse_hidden_states is not None:
235
+ sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
236
+
237
+ return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
238
+
239
+ def forward(
240
+ self,
241
+ query_layer,
242
+ key_layer,
243
+ value_layer,
244
+ attention_mask=None,
245
+ sparse_key=None,
246
+ sparse_value=None,
247
+ sparse_mask=None,
248
+ global_key=None,
249
+ global_value=None,
250
+ global_mask=None
251
+ ):
252
+
253
+ # Input batch, heads, length, hidden_size
254
+ n, h, t, d = query_layer.size()
255
+ n_blocks = t // self.block_size
256
+ assert t % self.block_size == 0
257
+
258
+ key_layer = self.build_lsg_inputs(
259
+ key_layer,
260
+ sparse_key,
261
+ global_key
262
+ )
263
+ del sparse_key
264
+ del global_key
265
+
266
+ value_layer = self.build_lsg_inputs(
267
+ value_layer,
268
+ sparse_value,
269
+ global_value
270
+ )
271
+ del sparse_value
272
+ del global_value
273
+
274
+ attention_mask = self.build_lsg_inputs(
275
+ attention_mask,
276
+ sparse_mask,
277
+ global_mask.transpose(-1, -2),
278
+ is_attn_mask=True
279
+ ).transpose(-1, -2)
280
+ del sparse_mask
281
+ del global_mask
282
+
283
+ # expect (..., t, d) shape
284
+ # Compute attention
285
+ context_layer = self.attention(
286
+ query_layer=self.chunk(query_layer, n_blocks),
287
+ key_layer=key_layer,
288
+ value_layer=value_layer,
289
+ attention_mask=attention_mask
290
+ )
291
+
292
+ return context_layer.reshape(n, h, -1, d)
293
+
294
+ def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
295
+
296
+ size, step = self.local_shapes
297
+ s = (size - step) // 2
298
+
299
+ # Pad before block reshaping
300
+ if is_attn_mask:
301
+ pad_value = -10000
302
+ hidden_states = hidden_states.transpose(-1, -2)
303
+ else:
304
+ pad_value = 0
305
+
306
+ hidden_states = torch.nn.functional.pad(
307
+ hidden_states.transpose(-1, -2),
308
+ pad=(s, s),
309
+ value=pad_value
310
+ ).transpose(-1, -2)
311
+
312
+ # Make blocks
313
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
314
+
315
+ # Skip third block if causal
316
+ if self.is_causal:
317
+ return hidden_states[..., :size*2//3, :]
318
+
319
+ return hidden_states
320
+
321
+ def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
322
+
323
+ size, step = self.sparse_shapes
324
+
325
+ # In case of odd case
326
+ odd_offset = (step % 2)
327
+
328
+ # n, h, t, d*2 + 1
329
+ size = size*2
330
+ s = (size - step) // 2 + odd_offset
331
+
332
+ # Pad before block reshaping
333
+ if is_attn_mask:
334
+ pad_value = -10000
335
+ hidden_states = hidden_states.transpose(-1, -2)
336
+ else:
337
+ pad_value = 0
338
+
339
+ hidden_states = torch.nn.functional.pad(
340
+ hidden_states.transpose(-1, -2),
341
+ pad=(s, s),
342
+ value=pad_value
343
+ ).transpose(-1, -2)
344
+
345
+ # Make blocks
346
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
347
+
348
+ # Fix case where block_size == sparsity_factor
349
+ if odd_offset:
350
+ hidden_states = hidden_states[..., :-1, :, :]
351
+
352
+ # Indexes for selection
353
+ u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
354
+ s = self.sparse_block_size
355
+
356
+ # Skip right block if causal
357
+ if self.is_causal:
358
+ return hidden_states[..., u-s:u, :]
359
+
360
+ u_ = u + odd_offset
361
+ return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
362
+
363
+ def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
364
+
365
+ n, h, b, t, d = x_local.size()
366
+ x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
367
+ if x_sparse is not None:
368
+ return torch.cat([x_global, x_sparse, x_local], dim=dim)
369
+ return torch.cat([x_global, x_local], dim=dim)
370
+
371
+ def chunk(self, x, n_blocks):
372
+
373
+ t, d = x.size()[-2:]
374
+ return x.reshape(*x.size()[:-2], n_blocks, -1, d)
375
+
376
+
377
+ class LSGCamembertEmbeddings(RobertaEmbeddings):
378
+
379
+ def __init__(self, config):
380
+ super().__init__(config)
381
+
382
+ self.num_global_tokens = config.num_global_tokens
383
+
384
+ # Hardcoded but partially trained
385
+ self.global_embeddings = nn.Embedding(512, embedding_dim=config.hidden_size, )
386
+
387
+ self.block_size = config.block_size
388
+
389
+ def forward(
390
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
391
+ ):
392
+ if position_ids is None:
393
+ if input_ids is not None:
394
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
395
+ position_ids = create_position_ids_from_input_ids(
396
+ input_ids, self.padding_idx, past_key_values_length
397
+ ).to(input_ids.device)
398
+ else:
399
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
400
+
401
+ if input_ids is not None:
402
+ input_shape = input_ids.size()
403
+ else:
404
+ input_shape = inputs_embeds.size()[:-1]
405
+
406
+ seq_length = input_shape[-1]
407
+
408
+ if token_type_ids is None:
409
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
410
+
411
+ if inputs_embeds is None:
412
+ inputs_embeds = self.word_embeddings(input_ids)
413
+ token_type_embeddings = self.token_type_embeddings(token_type_ids[:, :seq_length])
414
+
415
+ embeddings = inputs_embeds + token_type_embeddings
416
+ if self.position_embedding_type == "absolute":
417
+ position_embeddings = self.position_embeddings(position_ids[:, :seq_length])
418
+ embeddings += position_embeddings
419
+
420
+ #if self.num_global_tokens < 0:
421
+ n, t, d = embeddings.size()
422
+
423
+ # Add global_tokens
424
+ indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
425
+ global_embeddings = self.global_embeddings(indexes)
426
+ embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
427
+
428
+ embeddings = self.LayerNorm(embeddings)
429
+ embeddings = self.dropout(embeddings)
430
+ return embeddings
431
+
432
+
433
+ class LSGCamembertSelfOutput(RobertaSelfOutput):
434
+
435
+ def __init__(self, config):
436
+ super().__init__(config)
437
+
438
+
439
+ class LSGAttention(RobertaAttention):
440
+
441
+ def __init__(self, config):
442
+
443
+ nn.Module.__init__(self)
444
+
445
+ self.self = LSGSelfAttention(config)
446
+ self.output = LSGCamembertSelfOutput(config)
447
+ self.pruned_heads = set()
448
+
449
+
450
+ class LSGCamembertIntermediate(RobertaIntermediate):
451
+
452
+ def __init__(self, config):
453
+ super().__init__(config)
454
+
455
+
456
+ class LSGCamembertOutput(RobertaOutput):
457
+
458
+ def __init__(self, config):
459
+ super().__init__(config)
460
+
461
+
462
+ class LSGCamembertPooler(RobertaPooler):
463
+
464
+ def __init__(self, config):
465
+ super().__init__(config)
466
+
467
+
468
+ class LSGSelfAttention(BaseSelfAttention):
469
+ '''
470
+ Compute local attention with overlapping blocks
471
+ Use global attention for tokens with highest norm
472
+ '''
473
+ def __init__(self, config):
474
+ super().__init__()
475
+
476
+ self.init_modules(config)
477
+
478
+ self.block_size = config.block_size
479
+ self.sparse_block_size = config.sparse_block_size
480
+ self.num_global_tokens = config.num_global_tokens
481
+ self.sparsity_factor = config.sparsity_factor
482
+ self.is_causal = config.is_decoder
483
+ self.is_decoder = config.is_decoder
484
+
485
+ self.attention = LSGAttentionProduct(
486
+ config,
487
+ block_size=config.block_size,
488
+ sparse_block_size=config.sparse_block_size,
489
+ sparsity_factor=self.sparsity_factor,
490
+ is_causal=self.is_causal
491
+ )
492
+
493
+ if self.is_causal:
494
+ self.causal_attention = CausalAttentionProduct(config)
495
+ self.full_attention = BaseAttentionProduct(config)
496
+
497
+ sparse_functions = {
498
+ "norm": self.get_sparse_tokens_with_norm,
499
+ "pooling": self.get_sparse_tokens_with_pooling,
500
+ "lsh": self.get_sparse_tokens_with_lsh,
501
+ "stride": self.get_sparse_tokens_with_stride,
502
+ "block_stride": self.get_sparse_tokens_with_block_stride,
503
+ }
504
+
505
+ self.sparsity_type = config.sparsity_type
506
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
507
+
508
+ if config.sparsity_type == "lsh":
509
+ self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
510
+
511
+ def get_sparse_tokens_with_norm(self, keys, values, mask):
512
+
513
+ if self.sparsity_factor == 1:
514
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
515
+
516
+ with torch.no_grad():
517
+
518
+ block_size = min(self.block_size, self.sparse_block_size)
519
+ key_norm = keys.detach().norm(dim=-1, keepdim=True)
520
+ key_norm = key_norm * ~mask.transpose(-1, -2).bool()
521
+ key_norm = self.chunk(key_norm, block_size)
522
+
523
+ n, h, b, t, d = key_norm.size()
524
+
525
+ idx = key_norm.argsort(dim=-2)
526
+ del key_norm
527
+ idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
528
+
529
+ split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
530
+ sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
531
+
532
+ d = keys.size()[-1]
533
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
534
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
535
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
536
+
537
+ return keys, values, mask
538
+
539
+ def get_sparse_tokens_with_pooling(self, keys, values, mask):
540
+
541
+ if self.sparsity_factor == 1:
542
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
543
+
544
+ keys = self.chunk(keys, self.sparsity_factor)
545
+ values = self.chunk(values, self.sparsity_factor)
546
+
547
+ n, h, b, t, d = keys.size()
548
+ mask = mask.reshape(n, 1, b, 1, t)
549
+ mask = ~mask.transpose(-1, -2).bool()
550
+
551
+ keys = keys * mask
552
+ values = values * mask
553
+
554
+ mask = mask.sum(dim=-2)
555
+ keys = keys.sum(dim=-2) / (mask + 1e-6)
556
+ values = values.sum(dim=-2) / (mask + 1e-6)
557
+
558
+ mask = - (1. - mask.clamp(0, 1)) * 1e4
559
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
560
+
561
+ def get_sparse_tokens_with_stride(self, keys, values, mask):
562
+
563
+ if self.sparsity_factor == 1:
564
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
565
+
566
+ n, h, t, d = keys.size()
567
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
568
+ sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
569
+ sparse_idx = sparse_idx.expand(n, h, -1, 1)
570
+
571
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
572
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
573
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
574
+
575
+ return keys, values, mask
576
+
577
+ def get_sparse_tokens_with_block_stride(self, keys, values, mask):
578
+
579
+ if self.sparsity_factor == 1:
580
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
581
+
582
+ n, h, t, d = keys.size()
583
+
584
+ t, b = self.block_size, t // self.block_size
585
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
586
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
587
+ sparse_idx = (sparse_idx % t)
588
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
589
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
590
+
591
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
592
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
593
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
594
+
595
+ return keys, values, mask
596
+
597
+ def get_sparse_tokens_with_lsh(self, keys, values, mask):
598
+
599
+ if self.sparsity_factor == 1:
600
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
601
+
602
+ block_size = min(self.block_size, self.sparse_block_size)
603
+ keys = self.chunk(keys, block_size)
604
+ values = self.chunk(values, block_size)
605
+
606
+ n, h, b, t, d = keys.size()
607
+ mask = mask.reshape(n, 1, b, 1, t)
608
+ mask = ~mask.transpose(-1, -2).bool()
609
+
610
+ keys = keys * mask
611
+ values = values * mask
612
+ mask = mask.expand(-1, h, -1, -1, -1).float()
613
+
614
+ extra_factor = 1
615
+
616
+ for _ in range(self.lsh_num_pre_rounds):
617
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
618
+
619
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
620
+ keys /= mask + 1e-8
621
+ values /= mask + 1e-8
622
+
623
+ mask = -10000 * (1. - mask.clamp(0, 1))
624
+
625
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
626
+
627
+ def lsh_round(self, keys, values, mask, output_size):
628
+
629
+ with torch.no_grad():
630
+
631
+ n_hashes = output_size // 2
632
+ n, h, b, t, d = keys.size()
633
+ binary_mask = mask.clamp(0, 1)
634
+
635
+ indexes = (torch.nn.functional.normalize(keys, dim=-1) * binary_mask) @ torch.randn(1, h, 1, d, n_hashes, device=keys.device)
636
+ indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)
637
+
638
+ n, h, b, t, d = keys.size()
639
+
640
+ x_ = torch.zeros(n, h, b, output_size, d, device=keys.device)
641
+ mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device)
642
+ keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
643
+ values = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
644
+ mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)
645
+
646
+ return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
647
+
648
+ def forward(
649
+ self,
650
+ hidden_states,
651
+ attention_mask=None,
652
+ head_mask=None,
653
+ encoder_hidden_states=None,
654
+ encoder_attention_mask=None,
655
+ past_key_value=None,
656
+ output_attentions=False,
657
+ ):
658
+
659
+ query_layer = self.query(hidden_states)
660
+
661
+ # If this is instantiated as a cross-attention module, the keys
662
+ # and values come from an encoder; the attention mask needs to be
663
+ # such that the encoder's padding tokens are not attended to.
664
+ is_cross_attention = encoder_hidden_states is not None
665
+
666
+ if is_cross_attention and past_key_value is not None:
667
+ # reuse k,v, cross_attentions
668
+ key_layer = past_key_value[0]
669
+ value_layer = past_key_value[1]
670
+ attention_mask = encoder_attention_mask
671
+ elif is_cross_attention:
672
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
673
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
674
+ attention_mask = encoder_attention_mask
675
+ elif past_key_value is not None:
676
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
677
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
678
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
679
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
680
+ else:
681
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
682
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
683
+
684
+ query_layer = self.transpose_for_scores(query_layer)
685
+
686
+ if self.is_decoder:
687
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
688
+ # Further calls to cross_attention layer can then reuse all cross-attention
689
+ # key/value_states (first "if" case)
690
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
691
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
692
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
693
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
694
+ past_key_value = (key_layer, value_layer)
695
+
696
+ if is_cross_attention:
697
+ outputs = self.cross_attention_forward(
698
+ query_layer=query_layer,
699
+ key_layer=key_layer,
700
+ value_layer=value_layer,
701
+ attention_mask=attention_mask,
702
+ output_attentions=output_attentions
703
+ )
704
+ else:
705
+ outputs = self.causal_forward(
706
+ query_layer,
707
+ key_layer,
708
+ value_layer,
709
+ attention_mask=attention_mask,
710
+ output_attentions=output_attentions,
711
+ )
712
+
713
+ outputs = outputs + ((key_layer, value_layer),)
714
+
715
+ else:
716
+ outputs = self.not_causal_forward(
717
+ query_layer,
718
+ key_layer,
719
+ value_layer,
720
+ attention_mask=attention_mask,
721
+ output_attentions=output_attentions
722
+ )
723
+
724
+ #if head_mask is not None:
725
+ # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
726
+ return outputs
727
+
728
+ def causal_forward(
729
+ self,
730
+ query_layer,
731
+ key_layer,
732
+ value_layer,
733
+ attention_mask=None,
734
+ output_attentions=False,
735
+ ):
736
+
737
+ n, h, t, d = key_layer.size()
738
+
739
+ # Cat global mask
740
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
741
+
742
+ # Split input into global tokens and other tokens
743
+ split = (self.num_global_tokens, t - self.num_global_tokens)
744
+ global_query, query_layer = query_layer.split(split, dim=-2)
745
+
746
+ # Use normal causal attention if local attention covers every token
747
+ if t <= 2 * self.block_size + self.num_global_tokens:
748
+ context_layer = self.causal_attention(
749
+ query_layer=query_layer,
750
+ key_layer=key_layer,
751
+ value_layer=value_layer,
752
+ attention_mask=attention_mask,
753
+ causal_shape=(t - self.num_global_tokens, t - self.num_global_tokens)
754
+ )
755
+
756
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
757
+ return (self.reshape_output(context_layer), )
758
+
759
+ # Split K Q M on global and non global
760
+ global_key, key_layer = key_layer.split(split, dim=-2)
761
+ global_value, value_layer = value_layer.split(split, dim=-2)
762
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
763
+
764
+ n, h, t, d = key_layer.size()
765
+
766
+ # Get sparse idx
767
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
768
+ if self.sparse_block_size and self.sparsity_factor > 0:
769
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
770
+
771
+ # Expand masks on heads
772
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
773
+ global_mask = global_mask.expand(-1, h, -1, -1)
774
+
775
+ # Compute dot product attention
776
+ context_layer = self.attention(
777
+ query_layer,
778
+ key_layer,
779
+ value_layer,
780
+ attention_mask,
781
+ sparse_key=sparse_key,
782
+ sparse_value=sparse_value,
783
+ sparse_mask=sparse_mask,
784
+ global_key=global_key,
785
+ global_value=global_value,
786
+ global_mask=global_mask
787
+ )
788
+
789
+ # Merge pseudo global (causal) and local-sparse tokens
790
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
791
+ context_layer = self.reshape_output(context_layer)
792
+
793
+ return (context_layer,)
794
+
795
+ def not_causal_forward(
796
+ self,
797
+ query_layer,
798
+ key_layer,
799
+ value_layer,
800
+ attention_mask=None,
801
+ output_attentions=False,
802
+ ):
803
+
804
+ n, h, t, d = query_layer.size()
805
+
806
+ # Cat global mask
807
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
808
+
809
+ # Use normal attention if local attention covers every token
810
+ if t <= 2 * self.block_size + self.num_global_tokens:
811
+ context_layer = self.full_attention(
812
+ query_layer=query_layer,
813
+ key_layer=key_layer,
814
+ value_layer=value_layer,
815
+ attention_mask=attention_mask
816
+ )
817
+ return (self.reshape_output(context_layer), )
818
+
819
+ # Split input into global tokens and other tokens
820
+ split = (self.num_global_tokens, t - self.num_global_tokens)
821
+ global_query, query_layer = query_layer.split(split, dim=-2)
822
+
823
+ # Get global_attention
824
+ bos = self.full_attention(
825
+ query_layer=global_query,
826
+ key_layer=key_layer,
827
+ value_layer=value_layer,
828
+ attention_mask=attention_mask
829
+ )
830
+
831
+ # Split K Q M on global and non global
832
+ global_key, key_layer = key_layer.split(split, dim=-2)
833
+ global_value, value_layer = value_layer.split(split, dim=-2)
834
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
835
+
836
+ n, h, t, d = key_layer.size()
837
+
838
+ # Get sparse idx
839
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
840
+
841
+ if self.sparse_block_size and self.sparsity_factor > 0:
842
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
843
+
844
+ # Expand masks on heads
845
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
846
+ global_mask = global_mask.expand(-1, h, -1, -1)
847
+
848
+ # Compute dot product attention
849
+ context_layer = self.attention(
850
+ query_layer,
851
+ key_layer,
852
+ value_layer,
853
+ attention_mask,
854
+ sparse_key=sparse_key,
855
+ sparse_value=sparse_value,
856
+ sparse_mask=sparse_mask,
857
+ global_key=global_key,
858
+ global_value=global_value,
859
+ global_mask=global_mask
860
+ )
861
+
862
+ # Merge global and local-sparse tokens
863
+ context_layer = torch.cat([bos, context_layer], dim=-2)
864
+ context_layer = self.reshape_output(context_layer)
865
+
866
+ return (context_layer,)
867
+
868
+ def cross_attention_forward(
869
+ self,
870
+ query_layer,
871
+ key_layer,
872
+ value_layer,
873
+ attention_mask=None,
874
+ output_attentions=False,
875
+ ):
876
+
877
+ context_layer = self.full_attention(
878
+ query_layer=query_layer,
879
+ key_layer=key_layer,
880
+ value_layer=value_layer,
881
+ attention_mask=attention_mask
882
+ )
883
+ return (self.reshape_output(context_layer), )
884
+
885
+ def chunk(self, x, chunk_size):
886
+
887
+ n, h, t, d = x.size()
888
+ return x.reshape(n, h, -1, chunk_size, d)
889
+
890
+
891
+ class LSGCamembertLayer(RobertaLayer):
892
+
893
+ def __init__(self, config):
894
+
895
+ nn.Module.__init__(self)
896
+
897
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
898
+ self.seq_len_dim = 1
899
+ self.attention = LSGAttention(config)
900
+ self.is_decoder = config.is_decoder
901
+ self.add_cross_attention = config.add_cross_attention
902
+ if self.add_cross_attention:
903
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
904
+ self.crossattention = LSGAttention(config)
905
+ self.intermediate = LSGCamembertIntermediate(config)
906
+ self.output = LSGCamembertOutput(config)
907
+
908
+
909
+ class LSGCamembertEncoder(RobertaEncoder):
910
+
911
+ def __init__(self, config):
912
+
913
+ nn.Module.__init__(self)
914
+
915
+ self.config = config
916
+ self.layer = nn.ModuleList([LSGCamembertLayer(config) for _ in range(config.num_hidden_layers)])
917
+ self.gradient_checkpointing = False
918
+
919
+
920
+ class LSGCamembertPreTrainedModel(RobertaPreTrainedModel):
921
+ """
922
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
923
+ models.
924
+ """
925
+
926
+ config_class = LSGCamembertConfig
927
+
928
+ def _set_gradient_checkpointing(self, module, value=False):
929
+ if isinstance(module, (RobertaEncoder, LSGCamembertEncoder)):
930
+ module.gradient_checkpointing = value
931
+
932
+
933
+ class LSGCamembertModel(LSGCamembertPreTrainedModel, RobertaModel):
934
+ """
935
+ This class overrides :class:`~transformers.CamembertModel`. Please check the superclass for the appropriate
936
+ documentation alongside usage examples.
937
+ """
938
+
939
+ config_class = LSGCamembertConfig
940
+
941
+
942
+ def __init__(self, config, add_pooling_layer=False):
943
+
944
+ LSGCamembertPreTrainedModel.__init__(self, config)
945
+
946
+ assert hasattr(config, "num_global_tokens")
947
+ self.num_global_tokens = config.num_global_tokens
948
+ self.pad_idx = config.pad_token_id
949
+
950
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
951
+ self.block_size = config.block_size
952
+ self.adaptive = config.adaptive
953
+ self.pool_with_global = config.pool_with_global
954
+
955
+ self.embeddings = LSGCamembertEmbeddings(config)
956
+ self.encoder = LSGCamembertEncoder(config)
957
+ self.pooler = LSGCamembertPooler(config) if add_pooling_layer else None
958
+
959
+ if config.add_cross_attention:
960
+ logger.warning(
961
+ "Cross attention is computed using full attention since it is not LSG compatible."
962
+ )
963
+
964
+ # Initialize weights and apply final processing
965
+ self.post_init()
966
+
967
+ def forward(
968
+ self,
969
+ input_ids=None,
970
+ attention_mask=None,
971
+ token_type_ids=None,
972
+ position_ids=None,
973
+ head_mask=None,
974
+ inputs_embeds=None,
975
+ encoder_hidden_states=None,
976
+ encoder_attention_mask=None,
977
+ past_key_values=None,
978
+ use_cache=None,
979
+ output_attentions=None,
980
+ output_hidden_states=None,
981
+ return_dict=None
982
+ ):
983
+
984
+ inputs_ = input_ids if input_ids is not None else inputs_embeds
985
+ n, t = inputs_.size()[:2]
986
+
987
+ if attention_mask is None:
988
+ attention_mask = torch.ones(n, t, device=inputs_.device)
989
+
990
+ b = self.block_size * 2
991
+ pad = t % self.block_size
992
+
993
+ # Check if t is multiple of block_size and pad
994
+ if self.adaptive and t > b and pad > 0:
995
+ pad_length = self.block_size - pad
996
+ if input_ids is not None:
997
+ input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
998
+ else:
999
+ inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
1000
+
1001
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
1002
+
1003
+ if token_type_ids is not None:
1004
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
1005
+ if position_ids is not None:
1006
+ position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
1007
+
1008
+ n, t_ = attention_mask.size()
1009
+
1010
+ encoder_outputs = super().forward(
1011
+ input_ids=input_ids,
1012
+ attention_mask=attention_mask,
1013
+ token_type_ids=token_type_ids,
1014
+ position_ids=position_ids,
1015
+ head_mask=head_mask,
1016
+ inputs_embeds=inputs_embeds,
1017
+ encoder_hidden_states=encoder_hidden_states,
1018
+ encoder_attention_mask=encoder_attention_mask,
1019
+ past_key_values=past_key_values,
1020
+ use_cache=use_cache,
1021
+ output_attentions=output_attentions,
1022
+ output_hidden_states=output_hidden_states,
1023
+ return_dict=return_dict
1024
+ )
1025
+
1026
+ context = encoder_outputs[0]
1027
+ if self.pool_with_global:
1028
+ context[:, self.num_global_tokens] = context[:, 0]
1029
+
1030
+ diff = t - t_
1031
+ n, _, d = context.size()
1032
+ context = context[..., self.num_global_tokens:, :]
1033
+
1034
+ # Adapt sequence to initial shape
1035
+ if diff < 0:
1036
+ context = context[:, :t]
1037
+
1038
+ encoder_outputs.last_hidden_state = context
1039
+ sequence_output = encoder_outputs[0]
1040
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1041
+
1042
+ if not return_dict:
1043
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1044
+
1045
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1046
+ last_hidden_state=sequence_output,
1047
+ pooler_output=pooled_output,
1048
+ past_key_values=encoder_outputs.past_key_values,
1049
+ hidden_states=encoder_outputs.hidden_states,
1050
+ attentions=encoder_outputs.attentions,
1051
+ cross_attentions=encoder_outputs.cross_attentions,
1052
+ )
1053
+
1054
+ def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1055
+
1056
+ # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
1057
+ if attention_mask.dim() == 3:
1058
+ extended_attention_mask = attention_mask[:, None, :, :]
1059
+ elif attention_mask.dim() == 2:
1060
+ extended_attention_mask = attention_mask[:, None, None, :]
1061
+ else:
1062
+ raise ValueError(
1063
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
1064
+ )
1065
+
1066
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1067
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1068
+
1069
+ return extended_attention_mask
1070
+
1071
+
1072
+ class LSGCamembertForCausalLM(LSGCamembertPreTrainedModel, RobertaForCausalLM):
1073
+
1074
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1075
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1076
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1077
+
1078
+ def __init__(self, config):
1079
+
1080
+ LSGCamembertPreTrainedModel.__init__(self, config)
1081
+
1082
+ if not config.is_decoder:
1083
+ logger.warning("If you want to use `LSGCamembertLMHeadModel` as a standalone, add `is_decoder=True.`")
1084
+
1085
+ self.roberta = LSGCamembertModel(config, add_pooling_layer=False)
1086
+ self.lm_head = LSGCamembertLMHead(config)
1087
+
1088
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1089
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1090
+
1091
+ # Initialize weights and apply final processing
1092
+ self.post_init()
1093
+
1094
+
1095
+ class LSGCamembertForMaskedLM(LSGCamembertPreTrainedModel, RobertaForMaskedLM):
1096
+ """
1097
+ This class overrides :class:`~transformers.CamembertForMaskedLM`. Please check the superclass for the appropriate
1098
+ documentation alongside usage examples.
1099
+ """
1100
+
1101
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1102
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1103
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1104
+
1105
+ def __init__(self, config):
1106
+
1107
+ LSGCamembertPreTrainedModel.__init__(self, config)
1108
+
1109
+ if config.is_decoder:
1110
+ logger.warning(
1111
+ "If you want to use `LSGCamembertForMaskedLM` make sure `config.is_decoder=False` for "
1112
+ "bi-directional self-attention."
1113
+ )
1114
+
1115
+ self.roberta = LSGCamembertModel(config, add_pooling_layer=False)
1116
+ self.lm_head = LSGCamembertLMHead(config)
1117
+
1118
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1119
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1120
+
1121
+ # Initialize weights and apply final processing
1122
+ self.post_init()
1123
+
1124
+
1125
+ class LSGCamembertLMHead(RobertaLMHead):
1126
+ """LSG Head for masked language modeling."""
1127
+
1128
+ def __init__(self, config):
1129
+ super().__init__(config)
1130
+
1131
+
1132
+ class LSGCamembertForSequenceClassification(LSGCamembertPreTrainedModel, RobertaForSequenceClassification):
1133
+ """
1134
+ This class overrides :class:`~transformers.CamembertForSequenceClassification`. Please check the superclass for the
1135
+ appropriate documentation alongside usage examples.
1136
+ """
1137
+
1138
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1139
+
1140
+ def __init__(self, config):
1141
+
1142
+ LSGCamembertPreTrainedModel.__init__(self, config)
1143
+
1144
+ self.num_labels = config.num_labels
1145
+ self.config = config
1146
+
1147
+ self.roberta = LSGCamembertModel(config, add_pooling_layer=False)
1148
+ self.classifier = LSGCamembertClassificationHead(config)
1149
+
1150
+ # Initialize weights and apply final processing
1151
+ self.post_init()
1152
+
1153
+
1154
+ class LSGCamembertClassificationHead(RobertaClassificationHead):
1155
+ """Head for sentence-level classification tasks."""
1156
+
1157
+ def __init__(self, config):
1158
+ super().__init__(config)
1159
+
1160
+
1161
+ class LSGCamembertForMultipleChoice(LSGCamembertPreTrainedModel, RobertaForMultipleChoice):
1162
+ """
1163
+ This class overrides :class:`~transformers.CamembertForMultipleChoice`. Please check the superclass for the
1164
+ appropriate documentation alongside usage examples.
1165
+ """
1166
+
1167
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1168
+
1169
+ def __init__(self, config):
1170
+
1171
+ LSGCamembertPreTrainedModel.__init__(self, config)
1172
+
1173
+ self.roberta = LSGCamembertModel(config)
1174
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1175
+ self.classifier = nn.Linear(config.hidden_size, 1)
1176
+
1177
+ # Initialize weights and apply final processing
1178
+ self.post_init()
1179
+
1180
+
1181
+ class LSGCamembertForTokenClassification(LSGCamembertPreTrainedModel, RobertaForTokenClassification):
1182
+ """
1183
+ This class overrides :class:`~transformers.CamembertForTokenClassification`. Please check the superclass for the
1184
+ appropriate documentation alongside usage examples.
1185
+ """
1186
+
1187
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1188
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1189
+
1190
+ def __init__(self, config):
1191
+
1192
+ LSGCamembertPreTrainedModel.__init__(self, config)
1193
+
1194
+ self.num_labels = config.num_labels
1195
+
1196
+ self.roberta = LSGCamembertModel(config, add_pooling_layer=False)
1197
+ classifier_dropout = (
1198
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1199
+ )
1200
+ self.dropout = nn.Dropout(classifier_dropout)
1201
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1202
+
1203
+ # Initialize weights and apply final processing
1204
+ self.post_init()
1205
+
1206
+
1207
+ class LSGCamembertForQuestionAnswering(LSGCamembertPreTrainedModel, RobertaForQuestionAnswering):
1208
+ """
1209
+ This class overrides :class:`~transformers.CamembertForQuestionAnswering`. Please check the superclass for the
1210
+ appropriate documentation alongside usage examples.
1211
+ """
1212
+
1213
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1214
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1215
+
1216
+ def __init__(self, config):
1217
+
1218
+ LSGCamembertPreTrainedModel.__init__(self, config)
1219
+
1220
+ self.num_labels = config.num_labels
1221
+
1222
+ self.roberta = LSGCamembertModel(config, add_pooling_layer=False)
1223
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1224
+
1225
+ # Initialize weights and apply final processing
1226
+ self.post_init()
1227
+
1228
+
1229
+ def str_to_class(classname):
1230
+ return getattr(sys.modules[__name__], classname)
1231
+
1232
+ # Register model in Auto API
1233
+ try:
1234
+ LSGCamembertConfig.register_for_auto_class()
1235
+ for key, value in AUTO_MAP.items():
1236
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1237
+ except Exception:
1238
+ warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1239
+ warn("Update to transformers >= 4.17.0 to fix.")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:535f936be5fec968bb4ee2a475a7c2d05674ea7d5b1a166351fa3aec4b679c01
3
+ size 285162729
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:988bc5a00281c6d210a5d34bd143d0363741a432fefe741bf71e61b1869d4314
3
+ size 810912
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}, "additional_special_tokens": ["<s>NOTUSED", "</s>NOTUSED"]}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "additional_special_tokens": ["<s>NOTUSED", "</s>NOTUSED"], "model_max_length": 4096, "special_tokens_map_file": null, "name_or_path": "cmarkea/distilcamembert-base", "sp_model_kwargs": {}, "tokenizer_class": "CamembertTokenizer"}