NyxKrage committed
Commit 237f50f · verified · 1 Parent(s): a4fbc49

make xformers an optional dependency


This adapts the LlamaMLP module from the Llama modeling code in transformers to split the fused w12 weight during the forward pass, and uses it as the feed-forward layer when xformers is not available on the system.

This enables the model to be used on macOS, for example.
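
For context, here is a minimal sketch of the computation the fallback performs in place of xformers' fused SwiGLU kernel (the sizes below are made up for illustration): the fused w12 projection is chunked into its two halves during the forward pass and recombined as w3(SiLU(w1·x) * (w2·x)).

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 48  # illustrative sizes only

# One fused projection holds both SwiGLU input weights, mirroring the w12 layout.
w12 = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
w3 = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 4, hidden_size)

# Split the fused projection's output into the two halves, then apply SwiGLU.
x1, x2 = w12(x).chunk(2, dim=-1)
out = w3(F.silu(x1) * x2)
print(out.shape)  # torch.Size([2, 4, 16])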

Files changed (1)
  1. model.py +24 -2
model.py CHANGED
@@ -9,7 +9,11 @@ from torch.nn.functional import scaled_dot_product_attention
 from typing import Optional
 import numpy as np
 
-from xformers.ops import SwiGLU
+try:
+    from xformers.ops import SwiGLU
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func
@@ -100,6 +104,21 @@ class NeoBERTConfig(PretrainedConfig):
         self.max_length = max_length
         self.kwargs = kwargs
 
+# Adapted from transformers.models.llama.modeling_llama.LlamaMLP
+class NeobertMLP(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, bias=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.w12 = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=bias)
+        self.w3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x):
+        w1, w2 = self.w12(x).chunk(2, dim=-1)
+        w3 = self.w3(self.act_fn(w1) * w2)
+        return w3
+
 
 class EncoderBlock(nn.Module):
     """Transformer encoder block."""
@@ -117,7 +136,10 @@ class EncoderBlock(nn.Module):
         multiple_of = 8
         intermediate_size = int(2 * config.intermediate_size / 3)
         intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
-        self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+        if XFORMERS_AVAILABLE:
+            self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+        else:
+            self.ffn = NeobertMLP(config.hidden_size, intermediate_size, bias=False)
 
         # Layer norms
         self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
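
As a quick sanity check of the non-xformers path, the fallback can be instantiated directly with the same sizing math EncoderBlock uses. This is only a sketch: it assumes model.py is importable (with torch and transformers installed), and the 768/3072 sizes are example values, not taken from any particular config.

import torch
from model import NeobertMLP  # assumes model.py is on the import path

hidden_size, configured_intermediate = 768, 3072  # example values

# Same rounding as EncoderBlock.__init__: 2/3 of the configured size, rounded up to a multiple of 8.
multiple_of = 8
intermediate_size = int(2 * configured_intermediate / 3)
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)

ffn = NeobertMLP(hidden_size, intermediate_size, bias=False)
x = torch.randn(2, 128, hidden_size)
print(ffn(x).shape)  # torch.Size([2, 128, 768])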