benjamin commited on
Commit
274744f
·
verified ·
1 Parent(s): afc198b

Update modelling_tpu_gemma2.py

Browse files
Files changed (1) hide show
  1. modelling_tpu_gemma2.py +2 -1
modelling_tpu_gemma2.py CHANGED
@@ -1141,6 +1141,7 @@ class TPUGemma2ForCausalLM(TPUGemma2PreTrainedModel, GenerationMixin):
1141
  ):
1142
  # Overwritten: has a special cache type, `HybridCache`
1143
 
 
1144
  # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1145
  # Exception 1: when passing input_embeds, input_ids may be missing entries
1146
  # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
@@ -1205,7 +1206,7 @@ class TPUGemma2ForCausalLM(TPUGemma2PreTrainedModel, GenerationMixin):
1205
  )
1206
 
1207
  if self.config.expand_input_ids:
1208
- model_inputs["past_input_ids"] = input_ids
1209
 
1210
  return model_inputs
1211
 
 
1141
  ):
1142
  # Overwritten: has a special cache type, `HybridCache`
1143
 
1144
+ past_input_ids = input_ids
1145
  # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1146
  # Exception 1: when passing input_embeds, input_ids may be missing entries
1147
  # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
 
1206
  )
1207
 
1208
  if self.config.expand_input_ids:
1209
+ model_inputs["past_input_ids"] = past_input_ids
1210
 
1211
  return model_inputs
1212