from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window
    attention and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`]
    for sliding window attention and [`StaticCache`] for global attention. For more information, see the documentation
    of each subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. If you're using more than 1 computation device, you
            should pass the `layer_device_map` argument instead.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
            Mapping between the layers and their devices. This is required when you are manually initializing the
            cache and the model is split between different GPUs. You can check which layers are mapped to which
            device via the associated `model.hf_device_map`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HybridCache()
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: Union[torch.device, str] = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        if batch_size is not None:
            logger.warning_once(
                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
            )
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = batch_size or max_batch_size
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )
        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        # Every `layer_switch`-th layer uses global attention; all other layers use sliding window attention.
        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2
        self.is_sliding = torch.tensor(
            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Auxiliary per-layer store, populated lazily via `_set_chunk_cache` / `_get_chunk_cache`.
        self.chunk_cache = {}
        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            self.max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        device = torch.device(device) if device is not None else None
        for i in range(config.num_hidden_layers):
            if layer_device_map is not None:
                layer_device = layer_device_map[i]
            else:
                layer_device = device
            # Sliding-window layers get a buffer capped at the window size; global layers get the full length.
            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            # Mark the buffers as static addresses so `torch.compile` treats them as fixed storage.
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
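
        # Illustrative example (the config values here are assumptions): for a Gemma-2-style config with
        # `num_hidden_layers=4`, no `sliding_window_pattern` (so `layer_switch == 2`) and `sliding_window=4096`,
        # `self.is_sliding` is `[True, False, True, False]`: layers 0 and 2 get sliding buffers of shape
        # `(max_batch_size, num_key_value_heads, min(4096, max_cache_len), head_dim)`, while layers 1 and 3 get
        # global buffers of shape `(max_batch_size, num_key_value_heads, max_cache_len, head_dim)`.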
    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        if cache_position.shape[0] > max_cache_len:
            # Prefill longer than the window: only the last `max_cache_len` tokens fit in the cache.
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # The buffers are all zeros at this point, so `+=` is equivalent to `=` but compile-friendly.
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # Return the full states so the current forward pass still attends to the whole prompt.
            return key_states, value_states

        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        # `zero_()` followed by `+=` is equivalent to `=`, but avoids graph breaks under `torch.compile`.
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out
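
    # Worked example for `_sliding_update` above (an illustrative sketch with assumed values): with
    # `max_cache_len == 4` and a single decoded token at `cache_position == tensor([4])`, the position is clamped
    # to `[3]`, `to_shift` is `[True]`, and `indices` becomes `[1, 2, 3, 0]`. Indexing `k_out`/`v_out` with these
    # indices rotates the window left by one slot, after which the new key/value pair is written into the last
    # slot, so the buffer always holds the most recent `max_cache_len` tokens in order. While the position is still
    # inside the window, `to_shift` is `[False]`, `indices` is the identity `[0, 1, 2, 3]`, and no rotation occurs.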
    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        # Global-attention layers simply write the new entries at their positions in the full-length buffer.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def _set_chunk_cache(self, layer_idx, cache):
        # Store an auxiliary per-layer chunk cache entry.
        self.chunk_cache[layer_idx] = cache

    def _get_chunk_cache(self, layer_idx):
        # Fetch the auxiliary per-layer chunk cache entry, defaulting to `None` if it was never set.
        self.chunk_cache.setdefault(layer_idx, None)
        return self.chunk_cache[layer_idx]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")

        # Keep the pre-allocated buffers on the same device as the incoming states (e.g. after device dispatch).
        if self.key_cache[layer_idx].device != key_states.device:
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
        if self.value_cache[layer_idx].device != value_states.device:
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        key_states = key_states.to(k_out.dtype)
        value_states = value_states.to(v_out.dtype)

        # Sliding-window layers roll their fixed-size buffer; global layers write in place at `cache_position`.
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )
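
    # Typical call site for `update` (an illustrative sketch; the attribute names on the caller are assumptions):
    # inside a decoder layer's attention forward, the per-layer window setting selects the update path, e.g.
    #
    #     key_states, value_states = past_key_values.update(
    #         key_states,
    #         value_states,
    #         layer_idx=self.layer_idx,
    #         cache_kwargs={"cache_position": cache_position, "sliding_window": self.sliding_window},
    #     )
    #
    # A truthy `"sliding_window"` entry routes to `_sliding_update`; `None` or a missing entry routes to
    # `_static_update`.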
    def get_max_cache_shape(self) -> Optional[int]:
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # A position is occupied when any value along the head dim is non-zero. To save compute, only the
        # `[0, 0]` slice is checked, since the cache is filled across the other dims up to the sequence length.
        if layer_idx != 0:
            raise ValueError(
                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                "Using the `layer_idx` argument is not supported."
            )
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place zeroing preserves the tensors' static memory addresses for `torch.compile`.
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()

    @property
    def batch_size(self):
        logger.warning_once(
            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
        )
        return self.max_batch_size
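

# Usage sketch (illustrative): a single pre-allocated `HybridCache` can be reused across prompts by calling
# `reset()` between generations, which zeroes the key/value buffers without re-allocating them. The `model` and
# `tokenizer` objects mirror the docstring example and are assumptions about the caller's setup:
#
#     cache = HybridCache(config=model.config, max_batch_size=1, max_cache_len=256, device=model.device, dtype=model.dtype)
#     for prompt in ["My name is Gemma", "The capital of France is"]:
#         cache.reset()
#         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#         outputs = model.generate(**inputs, past_key_values=cache, max_new_tokens=16)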