make xformers an optional dependency
This adapts the LlamaMLP module from the Llama modeling code in transformers to handle splitting the fused w12 weight during the forward pass, and uses it when xformers is not available on the system.
This enables the model to be used on macOS, for example.
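The fallback works because a packed w12 projection can be split after the matmul: chunking its output into two halves is equivalent to applying separate w1 and w2 projections before the SwiGLU gating. A minimal sketch of that equivalence (the sizes and tensor names below are illustrative only, not taken from the patch):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, inter = 16, 24
w12 = nn.Linear(hidden, 2 * inter, bias=False)  # packed gate + up projection
w3 = nn.Linear(inter, hidden, bias=False)       # down projection
x = torch.randn(2, 5, hidden)

# Fused path, as in the NeobertMLP below: one matmul, then chunk the output.
g, u = w12(x).chunk(2, dim=-1)
fused = w3(F.silu(g) * u)

# Equivalent "separate weights" path: slice w12 into w1 and w2 first.
w1, w2 = w12.weight.chunk(2, dim=0)             # each has shape (inter, hidden)
split = (F.silu(x @ w1.T) * (x @ w2.T)) @ w3.weight.T

assert torch.allclose(fused, split, atol=1e-6)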
model.py
CHANGED
@@ -9,7 +9,11 @@ from torch.nn.functional import scaled_dot_product_attention
 from typing import Optional
 import numpy as np
 
-from xformers.ops import SwiGLU
+try:
+    from xformers.ops import SwiGLU
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func
@@ -100,6 +104,21 @@ class NeoBERTConfig(PretrainedConfig):
         self.max_length = max_length
         self.kwargs = kwargs
 
+# Adapted from transformers.models.llama.modeling_llama.LlamaMLP
+class NeobertMLP(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, bias=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.w12 = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=bias)
+        self.w3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x):
+        w1, w2 = self.w12(x).chunk(2, dim=-1)
+        w3 = self.w3(self.act_fn(w1) * w2)
+        return w3
+
 
 class EncoderBlock(nn.Module):
     """Transformer encoder block."""
@@ -117,7 +136,10 @@ class EncoderBlock(nn.Module):
         multiple_of = 8
         intermediate_size = int(2 * config.intermediate_size / 3)
         intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
-        self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+        if XFORMERS_AVAILABLE:
+            self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+        else:
+            self.ffn = NeobertMLP(config.hidden_size, intermediate_size, bias=False)
 
         # Layer norms
         self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
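Because the fallback keeps the same parameter names (w12, w3) as the packed xformers SwiGLU module, a checkpoint trained with xformers should load into it without remapping keys; that packed layout is an assumption here, not something verified in this patch. A rough usage sketch with made-up config sizes, assuming model.py is importable as model:

from model import NeobertMLP  # assumes model.py is on the import path

# Hypothetical sizes, for illustration only.
hidden_size, config_intermediate_size = 768, 4096

# Same rounding as in EncoderBlock.__init__ above: 2/3 of the configured
# intermediate size, rounded up to a multiple of 8 (4096 -> 2730 -> 2736).
multiple_of = 8
intermediate_size = int(2 * config_intermediate_size / 3)
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)

ffn = NeobertMLP(hidden_size, intermediate_size, bias=False)
print(sorted(ffn.state_dict().keys()))  # ['w12.weight', 'w3.weight']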