# Page-status residue captured together with this file (not part of the script):
#   Spaces: Runtime error / Runtime error
import argparse
import os

from transformers import AutoConfig

from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


if __name__ == "__main__":
    # Script entry point: build a small Parler-TTS decoder from a hand-written
    # config, save it, then assemble and save a full model that combines that
    # decoder with a text encoder and an audio encoder loaded by name.
    cli = argparse.ArgumentParser()
    cli.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
    args = cli.parse_args()

    # Both artefacts are written underneath the user-supplied directory.
    decoder_dir = os.path.join(args.save_directory, "decoder")
    model_dir = os.path.join(args.save_directory, "tiny-model")

    text_model = "google-t5/t5-small"
    encodec_version = "facebook/encodec_24khz"

    # Only the configs are needed at this point: the text vocabulary size and
    # the audio codebook size are read from them below.
    t5 = AutoConfig.from_pretrained(text_model)
    encodec = AutoConfig.from_pretrained(encodec_version)

    encodec_vocab_size = encodec.codebook_size
    num_codebooks = 8
    print("num_codebooks", num_codebooks)

    # Special token ids sit just past the codebook entries: padding and
    # end-of-sequence share one id, beginning-of-sequence takes the next one.
    pad_eos_id = encodec_vocab_size
    start_id = encodec_vocab_size + 1

    decoder_settings = {
        "vocab_size": encodec_vocab_size + 1,
        "max_position_embeddings": 2048,
        "num_hidden_layers": 4,
        "ffn_dim": 512,
        "num_attention_heads": 8,
        "layerdrop": 0.0,
        "use_cache": True,
        "activation_function": "gelu",
        "hidden_size": 512,
        "dropout": 0.0,
        "attention_dropout": 0.0,
        "activation_dropout": 0.0,
        "pad_token_id": pad_eos_id,
        "eos_token_id": pad_eos_id,
        "bos_token_id": start_id,
        "num_codebooks": num_codebooks,
    }
    decoder_config = ParlerTTSDecoderConfig(**decoder_settings)

    # The decoder is constructed from the config alone (no decoder checkpoint
    # is loaded here) and saved so it can be referenced by path just below.
    decoder = ParlerTTSForCausalLM(decoder_config)
    decoder.save_pretrained(decoder_dir)

    model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
        text_encoder_pretrained_model_name_or_path=text_model,
        audio_encoder_pretrained_model_name_or_path=encodec_version,
        decoder_pretrained_model_name_or_path=decoder_dir,
        vocab_size=t5.vocab_size,
    )

    # Generation defaults: same special token ids as the decoder config.
    gen_cfg = model.generation_config
    gen_cfg.decoder_start_token_id = start_id
    gen_cfg.pad_token_id = pad_eos_id
    gen_cfg.eos_token_id = pad_eos_id

    # Cap generation at 30x the audio encoder's frame rate
    # (presumably 30 seconds of audio frames; confirm the frame_rate unit).
    gen_cfg.max_length = int(30 * model.audio_encoder.config.frame_rate)
    gen_cfg.do_sample = True

    # Mirror the special token ids on the top-level model config as well.
    model.config.pad_token_id = pad_eos_id
    model.config.decoder_start_token_id = start_id

    model.save_pretrained(model_dir)