from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaForSequenceClassification,
    LlamaModel,
    LlamaPreTrainedModel,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LlamaBidirectionalConfig(LlamaConfig):
    """Configuration for LlamaBidirectionalModel.

    Extends LlamaConfig with two extra fields: `pooling` (how token hidden states
    are aggregated, "avg" by default) and `temperature` (a scaling factor kept on
    the config).
    """

    model_type = "llama_bidirec"

    def __init__(self, pooling="avg", temperature=1.0, **kwargs):
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)
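

# Usage sketch (assuming standard PretrainedConfig serialization): the extra
# `pooling` and `temperature` fields, together with model_type "llama_bidirec",
# are written to config.json like any other config attribute, e.g.
#
#   config = LlamaBidirectionalConfig(pooling="avg", temperature=1.0)
#   config.save_pretrained("path/to/output")  # hypothetical output directory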


class LlamaBidirectionalModel(LlamaModel):
    config_class = LlamaBidirectionalConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Disable the causal flag on every attention layer so tokens can attend
        # in both directions (encoder-style attention).
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        assert self.config._attn_implementation in ["flash_attention_2", "eager"], (
            f"Unsupported attention implementation: {self.config._attn_implementation}; "
            "only flash_attention_2 and eager are supported"
        )

        # Let the parent implementation decide whether any mask is needed at all.
        calculated_attention_mask = super()._update_causal_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
            output_attentions,
        )
        if calculated_attention_mask is None:
            return None

        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention consumes the 2D padding mask directly; return it only
            # when there is actual padding, otherwise no mask is required.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        elif self.config._attn_implementation == "eager":
            # Expand the padding mask to 4D without adding a causal (lower-triangular)
            # bias, so every token can attend to every non-padded token.
            causal_mask = _prepare_4d_attention_mask(
                attention_mask,
                dtype=input_tensor.dtype,
            )
            return causal_mask
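

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the model definition): builds a toy
    # bidirectional encoder, runs a forward pass with padding, and mean-pools the
    # hidden states. The tiny config values below are illustrative only, and forcing
    # eager attention via `config._attn_implementation` assumes current transformers
    # internals.
    config = LlamaBidirectionalConfig(
        pooling="avg",
        temperature=1.0,
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    config._attn_implementation = "eager"  # the mask logic above supports eager or flash_attention_2

    model = LlamaBidirectionalModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[1, 6:] = 0  # simulate right padding on the second sequence

    with torch.no_grad():
        hidden = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

    # "avg" pooling as implied by the config default: mask out padded positions
    # before averaging token embeddings into one vector per sequence.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
    print(embeddings.shape)  # torch.Size([2, 64])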