Update modelling_tpu_gemma2.py
Browse files- modelling_tpu_gemma2.py +2 -1
modelling_tpu_gemma2.py
CHANGED
@@ -1141,6 +1141,7 @@ class TPUGemma2ForCausalLM(TPUGemma2PreTrainedModel, GenerationMixin):
|
|
1141 |
):
|
1142 |
# Overwritten: has a special cache type, `HybridCache`
|
1143 |
|
|
|
1144 |
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
1145 |
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
1146 |
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
@@ -1205,7 +1206,7 @@ class TPUGemma2ForCausalLM(TPUGemma2PreTrainedModel, GenerationMixin):
|
|
1205 |
)
|
1206 |
|
1207 |
if self.config.expand_input_ids:
|
1208 |
-
model_inputs["past_input_ids"] =
|
1209 |
|
1210 |
return model_inputs
|
1211 |
|
|
|
1141 |
):
|
1142 |
# Overwritten: has a special cache type, `HybridCache`
|
1143 |
|
1144 |
+
past_input_ids = input_ids
|
1145 |
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
1146 |
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
1147 |
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
|
|
1206 |
)
|
1207 |
|
1208 |
if self.config.expand_input_ids:
|
1209 |
+
model_inputs["past_input_ids"] = past_input_ids
|
1210 |
|
1211 |
return model_inputs
|
1212 |
|