kyleliang committed
Commit 76f5e06 · verified · 1 Parent(s): d60851f

Upload folder using huggingface_hub

Files changed (1)
  1. README.md +29 -94
README.md CHANGED
@@ -6,7 +6,7 @@ tags:
 
 # LagKV Cache
 
-#### Introduction
+## Introduction
 
 ![LagKV Cache diagram from the original paper](https://arxiv.org/html/2504.04704v1/x1.png)
 
@@ -18,99 +18,34 @@ Details are in the following work:
 
 [LagKV: Lag-Relative Information of the KV Cache Tells Which Tokens Are Important](https://arxiv.org/abs/2504.04704)
 
-#### How to Use
+## Example usage
 
-LagKV implements the Cache interface from transformers, so it is easy to integrate into the model's generation call.
+We can use the custom generation method in this repository like the base `generate` from `transformers`:
 
-```python
-from lag_kv import LagKV
+```py
+# requires `transformers>=4.52.0`
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
-model_path = "Qwen2.5-7B-Instruct"
-device = "cuda:0"
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", attn_implementation="sdpa").to(device)
-
-prompt = "long text"
-input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
-past_key_values = LagKV(lag_size=64)
-print(model.generate(input_ids, past_key_values=past_key_values))
-# check KV cache size
-print(past_key_values[0][0].size())
-```
-
-To compress the KV cache during the prefill stage, rather than first computing it exactly, use the following inference function (batch_size=1 only):
-
-```python
-def inference_by_prefill_compress(model, tokenizer, inputs, max_new_tokens=256, decode=False, past_key_values=None, device="cuda"):
-    if isinstance(inputs, str):
-        input_ids = tokenizer([inputs], return_tensors="pt")["input_ids"].to(device)
-    else:
-        input_ids = inputs
-    if past_key_values is None:
-        past_key_values = LagKV(ratio=0.2,
-                                lag_size=128,
-                                layer_idx_skip_first=[],
-                                use_then_compress=True)
-
-    with torch.no_grad():
-        sink_size = past_key_values.sink_size
-        lag_size = past_key_values.lag_size
-        trigger_len = sink_size + 2*lag_size
-        input_length = input_ids.shape[1]
-        # print(input_length > trigger_len)
-        if input_length > trigger_len:
-            start_idx = 0
-            end_idx = trigger_len
-            position_ids = torch.arange(input_length + max_new_tokens).unsqueeze(0).to(device)
-            def batch_input():
-                sel_input_ids = input_ids[:, start_idx:end_idx]
-                q_len = end_idx - start_idx
-                k_len = past_key_values.get_seq_length() + q_len
-                batch_size = input_ids.shape[0]
-                head_num = model.config.num_attention_heads
-                attn_mask = torch.ones((k_len, q_len),
-                                       device=input_ids.device, dtype=torch.bool)
-                attn_mask = torch.triu(attn_mask, diagonal=1).T
-                attn_mask = torch.flip(attn_mask, (0, 1))
-                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
-                attn_mask = attn_mask.expand(batch_size, -1, -1, -1).expand(-1, head_num, -1, -1)
-                attention_mask = torch.zeros((batch_size, head_num, q_len, k_len), device=input_ids.device, dtype=torch.bfloat16)
-                attention_mask.masked_fill_(attn_mask, -torch.inf)
-                return {"input_ids": sel_input_ids, "attention_mask": attention_mask}
-
-            while start_idx < input_length:
-                tmp_pos = position_ids[:, start_idx:end_idx]
-                outputs = model(**batch_input(),
-                                past_key_values=past_key_values,
-                                position_ids=tmp_pos,
-                                cache_position=tmp_pos[0]
-                                )
-                start_idx = end_idx
-                end_idx += lag_size
-                end_idx = min(end_idx, input_length)
-
-            new_token_id = outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)
-            # print(new_token_id)
-            new_token_count = 1
-            generated_ids = [new_token_id]
-            while new_token_id[0][0] != tokenizer.eos_token_id and new_token_count < max_new_tokens+1:
-                tmp_pos = position_ids[:, (input_length+new_token_count-1):(input_length+new_token_count)]
-                outputs = model(new_token_id,
-                                past_key_values=past_key_values,
-                                position_ids=tmp_pos,
-                                cache_position=tmp_pos[0]
-                                )
-                new_token_id = outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)
-                new_token_count += 1
-                generated_ids.append(new_token_id)
-            generated_ids = torch.cat(generated_ids, dim=-1)
-        else:
-            generated_ids = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens, past_key_values=past_key_values)
-            generated_ids = generated_ids[:, input_length:]
-    if decode:
-        output = tokenizer.batch_decode(generated_ids)
-    else:
-        output = generated_ids
-    return output, past_key_values
-```
+# Preparing model, tokenizer, and model inputs
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")
+messages = [{"role": "user", "content": "Tell me a story about a cat."}]
+text = tokenizer.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True,
+    enable_thinking=False
+)
+model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+# Using the LagKV cache
+gen_out = model.generate(
+    # usual `generate` arguments
+    **model_inputs,
+    do_sample=False,
+    max_new_tokens=100,
+    return_dict_in_generate=True,
+    # lagkv cache arguments (default `lag_ratio=0.5, lag_size=128, lag_sink_size=16`)
+    custom_generate="CMB-AI-LAB/lagkv_cache",
+    trust_remote_code=True,
+)
+print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
+assert "lagkvcache" in str(type(gen_out.past_key_values)).lower()
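
The comment in the added example lists the lagkv cache arguments and their defaults (`lag_ratio=0.5, lag_size=128, lag_sink_size=16`), which suggests they can be overridden in the same `generate` call. The sketch below is not part of this commit: it reuses `model`, `tokenizer`, and `model_inputs` from the example above and assumes the custom generation recipe accepts these keyword arguments; the parameter names are taken from that comment, not verified against the recipe's code.

```py
# A minimal sketch, assuming the CMB-AI-LAB/lagkv_cache recipe forwards these
# kwargs to the LagKV cache; the names come from the README comment above.
gen_out = model.generate(
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    custom_generate="CMB-AI-LAB/lagkv_cache",
    trust_remote_code=True,
    # hypothetical non-default compression settings (see the LagKV paper for
    # how the ratio and lag size trade off compression against accuracy)
    lag_ratio=0.25,
    lag_size=64,
    lag_sink_size=16,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
```

Comparing `gen_out.past_key_values.get_seq_length()` with the prompt length is one way to check how much of the cache was actually retained.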