from transformers import Gemma3nAudioEncoder, Gemma3nConfig
from transformers import AutoFeatureExtractor, PreTrainedModel
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder


class Audio(PreTrainedModel):
    """Container for the audio encoder and its multimodal embedder.

    Holds two submodules built from a single config object:

    * ``audio_tower`` -- a ``Gemma3nAudioEncoder`` built from
      ``config.audio_config``.
    * ``embed_audio`` -- a ``Gemma3nMultimodalEmbedder`` built from
      ``config.audio_config`` and ``config.text_config``.
    """

    def __init__(self, config):
        """Build both submodules from ``config``.

        Args:
            config: Configuration exposing ``audio_config`` and
                ``text_config`` attributes.
        """
        super().__init__(config)
        audio_cfg = config.audio_config
        # NOTE(review): the attribute names below determine the state-dict
        # keys; presumably they mirror an existing checkpoint layout, so do
        # not rename them without confirming.
        self.audio_tower = Gemma3nAudioEncoder(audio_cfg)
        self.embed_audio = Gemma3nMultimodalEmbedder(audio_cfg, config.text_config)


class GemmaAudio(PreTrainedModel):
    """Audio-only model: runs the audio tower, then the multimodal embedder."""

    config_class = Gemma3nConfig

    def __init__(self, config):
        """Create the wrapped :class:`Audio` module under ``self.model``.

        Args:
            config: Configuration forwarded unchanged to :class:`Audio`.
        """
        super().__init__(config)
        self.model = Audio(config)

    def forward(self, input_features, input_features_mask, **kwargs):
        """Encode audio features and project them with ``embed_audio``.

        Args:
            input_features: Passed as the first argument to the audio tower.
            input_features_mask: Mask that is inverted with ``~`` before being
                passed as the second argument to the audio tower.
            **kwargs: Accepted but not used.

        Returns:
            A 2-tuple ``(projected, second)`` where ``projected`` is the
            result of ``embed_audio`` applied to element 0 of the audio
            tower's output and ``second`` is element 1 of that output
            (presumably the encoder's output mask -- confirm).
        """
        # NOTE(review): ``~`` is a logical NOT only for boolean masks; this
        # assumes the caller supplies a bool mask -- confirm.
        inverted_mask = ~input_features_mask
        encoder_result = self.model.audio_tower(input_features, inverted_mask)

        # Index rather than unpack so any indexable return type keeps working.
        projected = self.model.embed_audio(inputs_embeds=encoder_result[0])
        return projected, encoder_result[1]