import argparse
import json
import logging
import os
from typing import Any, Dict, Optional

import torch
from huggingface_hub import hf_hub_download, upload_folder
from safetensors.torch import load_file, save_file

from transformers.models.blt_wip.configuration_blt import BLTConfig
from transformers.models.blt_wip.modeling_blt import BLTModel
from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
from transformers.utils import logging as transformers_logging


logger = transformers_logging.get_logger(__name__)
transformers_logging.set_verbosity_info()


def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
    logger.info("Merging configurations")

    with open(config_path, "r") as f:
        main_config = json.load(f)
    with open(entropy_params_path, "r") as f:
        entropy_data = json.load(f)

    entropy_model_params = entropy_data.get("entropy_model", {})
    patcher_args = entropy_data.get("data", {}).get("patcher_args", {})

    unified_config = main_config.copy()["args"]

    for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
        if key in unified_config and not isinstance(unified_config[key], int):
            unified_config[key] = int(unified_config[key])

    patch_size = patcher_args.get("patch_size", 8)
    if isinstance(patch_size, float):
        patch_size = int(patch_size)

    # Create patcher config
    patcher_hidden_size = int(entropy_model_params.get("dim", 512))
    patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256))
    patcher_intermediate_size = patcher_multiple_of * (
        (int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of
    )
    patcher_config = {
        "vocab_size": int(entropy_model_params.get("vocab_size", 256)),
        "hidden_size": patcher_hidden_size,
        "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)),
        "num_attention_heads": int(entropy_model_params.get("n_heads", 8)),
        "num_key_value_heads": int(entropy_model_params.get("n_kv_heads"))
        if entropy_model_params.get("n_kv_heads") is not None
        else None,
        "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)),
        "norm_eps": entropy_model_params.get("norm_eps", 1e-5),
        "dropout": entropy_model_params.get("dropout", 0.0),
        "rope_theta": entropy_model_params.get("rope_theta", 10000.0),
        "attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
        "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
        "intermediate_size": patcher_intermediate_size,
    }

    # Create encoder config
    encoder_hidden_size = unified_config.get("dim_local_encoder", 1024)
    encoder_multiple_of = unified_config.get("multiple_of", 256)
    encoder_intermediate_size = encoder_multiple_of * (
        (int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of
    )
    encoder_config = {
        "vocab_size": unified_config.get("vocab_size", 256),
        "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False),
        "cross_attn_k": unified_config.get("cross_attn_k", 2),
        "hidden_size_global": unified_config.get("hidden_size_global", 2048),
        "pm_size": unified_config.get("pm_size", 0),
        "hidden_size": encoder_hidden_size,
        "num_attention_heads": unified_config.get("n_heads_local_encoder", 16),
        "num_key_value_heads": unified_config.get("n_kv_heads"),
        "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1),
        "norm_eps": unified_config.get("norm_eps", 1e-5),
        "dropout": unified_config.get("dropout", 0.0),
        "max_position_embeddings": unified_config.get("max_encoder_seq_length")
        or unified_config.get("max_seqlen", 1024),
        "rope_theta": unified_config.get("rope_theta", 10000.0),
        "rope_scaling": {"rope_type": "default"},
        "hidden_act": unified_config.get("hidden_act", "silu"),
        "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
        "intermediate_size": encoder_intermediate_size,
    }

    # Create decoder config
    decoder_hidden_size = unified_config.get("dim_local_decoder", 1024)
    decoder_multiple_of = unified_config.get("multiple_of", 256)
    decoder_intermediate_size = decoder_multiple_of * (
        (int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of
    )
    decoder_config = {
        "vocab_size": unified_config.get("vocab_size", 256),
        "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False),
        "cross_attn_k": unified_config.get("cross_attn_k", 2),
        "hidden_size_global": unified_config.get("hidden_size_global", 2048),
        "hidden_size": decoder_hidden_size,
        "num_attention_heads": unified_config.get("n_heads_local_decoder", 16),
        "num_key_value_heads": unified_config.get("n_kv_heads"),
        "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9),
        "norm_eps": unified_config.get("norm_eps", 1e-5),
        "dropout": unified_config.get("dropout", 0.0),
        "max_position_embeddings": unified_config.get("max_encoder_seq_length")
        or unified_config.get("max_seqlen", 1024),
        "rope_theta": unified_config.get("rope_theta", 10000.0),
        "rope_scaling": {"rope_type": "default"},
        "hidden_act": unified_config.get("hidden_act", "silu"),
        "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
        "intermediate_size": decoder_intermediate_size,
    }

    # Create global transformer config
    global_hidden_size = unified_config.get("dim_global", 2048)
    global_multiple_of = unified_config.get("multiple_of", 256)
    global_intermediate_size = global_multiple_of * (
        (int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of
    )
    global_config = {
        "hidden_size": global_hidden_size,
        "num_attention_heads": unified_config.get("n_heads_global", 16),
        "num_key_value_heads": unified_config.get("n_kv_heads_global"),
        "num_hidden_layers": unified_config.get("n_layers_global", 25),
        "norm_eps": unified_config.get("norm_eps", 1e-5),
        "dropout": unified_config.get("dropout", 0.0),
        "max_position_embeddings": unified_config.get("max_seqlen", 1024),
        "rope_theta": unified_config.get("rope_theta", 10000.0),
        "rope_scaling": {"rope_type": "default"},
        "hidden_act": unified_config.get("hidden_act", "silu"),
        "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
        "intermediate_size": global_intermediate_size,
    }

    # Create main config with sub-configs
    main_config_dict = {
        "model_type": "blt",
        "vocab_size": unified_config.get("vocab_size", 256),
        "max_position_embeddings": unified_config.get("max_seqlen", 1024),
        "patch_in_forward": True,
        "realtime_patching": True,
        "patching_mode": "entropy",
        "patch_size": patch_size,
        "patching_threshold": patcher_args.get("threshold", 0.5),
        "patching_threshold_add": patcher_args.get("threshold_add", 0.0),
        "max_patch_length": patcher_args.get("max_patch_length"),
        "patching_batch_size": patcher_args.get("patching_batch_size", 1),
        "patching_device": patcher_args.get("patching_device", "cuda"),
        "monotonicity": patcher_args.get("monotonicity", False),
        "cross_attn_k": unified_config.get("cross_attn_k", 2),
        "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"),
        "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000),
        "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3),
        "pm_size": unified_config.get("pm_size", 0),
        "patcher_config": patcher_config,
        "encoder_config": encoder_config,
        "decoder_config": decoder_config,
        "global_config": global_config,
    }
    main_config_dict["tie_word_embeddings"] = False

    logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
    return main_config_dict


def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename original BLT checkpoint keys to the transformers-style naming scheme."""
    component_mappings = {
        ".attention.": ".self_attn.",
        ".feed_forward.": ".mlp.",
        ".attention_norm.": ".input_layernorm.",
        ".ffn_norm.": ".post_attention_layernorm.",
        ".tok_embeddings.": ".embed_tokens.",
        ".cross_attn_norm_q.": ".q_norm.",
        ".cross_attn_norm_kv.": ".k_norm.",
        ".w1.": ".gate_proj.",
        ".w2.": ".down_proj.",
        ".w3.": ".up_proj.",
        ".wq.": ".q_proj.",
        ".wk.": ".k_proj.",
        ".wv.": ".v_proj.",
        ".wo.": ".o_proj.",
        ".output.": ".lm_head.",
    }

    new_state_dict = {}
    for old_key, tensor in state_dict.items():
        new_key = old_key
        for old_pattern, new_pattern in component_mappings.items():
            if old_pattern in new_key:
                new_key = new_key.replace(old_pattern, new_pattern)
        new_state_dict[new_key] = tensor
    return new_state_dict


def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
    """Merge the main BLT weights and the entropy-model weights into a single state dict."""
    main_weights = load_file(weights_path)
    entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True)
    if "model" in entropy_weights:
        entropy_weights = entropy_weights["model"]
    elif "state_dict" in entropy_weights:
        entropy_weights = entropy_weights["state_dict"]

    # The entropy model becomes the "patcher" sub-module of the unified model
    unified_weights = main_weights.copy()
    for key, tensor in entropy_weights.items():
        patcher_key = f"patcher.{key}"
        unified_weights[patcher_key] = tensor

    unified_weights = apply_weight_mapping(unified_weights)

    # Move the decoder output projection to the top-level lm_head
    decoder_lm_head_key = "local_decoder.lm_head.weight"
    top_lm_head_key = "lm_head.weight"
    unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
    del unified_weights[decoder_lm_head_key]

    # Prefix everything except lm_head with "model."
    prefixed_weights = {}
    for key, tensor in unified_weights.items():
        if key == top_lm_head_key:
            prefixed_weights[key] = tensor
        elif not key.startswith("model."):
            prefixed_weights[f"model.{key}"] = tensor
        else:
            prefixed_weights[key] = tensor
    unified_weights = prefixed_weights

    return unified_weights


def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
    tokenizer_config = {
        "tokenizer_class": "BltTokenizer",
        "vocab_size": config.get("vocab_size", 256),
        "model_max_length": config.get("max_position_embeddings", 1024),
        "add_bos_token": True,
        "add_eos_token": True,
        "bos_token": "",
        "eos_token": "",
        "pad_token": "",
        "unk_token": "",
    }

    tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
    with open(tokenizer_path, "w") as f:
        json.dump(tokenizer_config, f, indent=2)


def push_to_hub(
    local_dir: str,
    repo_id: str,
    commit_message: str = "Upload converted BLT model",
    private: bool = False,
    token: Optional[str] = None,
) -> None:
    try:
        upload_folder(
            folder_path=local_dir,
            repo_id=repo_id,
            commit_message=commit_message,
            repo_type="model",
            token=token,
        )
        logger.info(f"Successfully pushed model to {repo_id}")
    except Exception as e:
        logger.error(f"Failed to push model to Hub: {e}")
        raise


def convert_hf_blt_to_unified(
    model_id: str,
    output_dir: str,
    config_name: str = "config.json",
    weights_name: str = "model.bin",
    cache_dir: Optional[str] = None,
    push_to_hub_repo: Optional[str] = None,
    hub_private: bool = False,
    hub_token: Optional[str] = None,
) -> None:
    # Download model files
    config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
    weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
    entropy_params_path = hf_hub_download(
        repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir
    )
    entropy_weights_path = hf_hub_download(
        repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir
    )

    unified_config = merge_configurations(config_path, entropy_params_path)
    unified_weights = merge_weights(weights_path, entropy_weights_path)

    os.makedirs(output_dir, exist_ok=True)

    output_config_path = os.path.join(output_dir, config_name)
    with open(output_config_path, "w") as f:
        json.dump(unified_config, f, indent=2)

    # Always save the merged weights as safetensors
    if weights_name.endswith(".bin"):
        weights_name = weights_name.replace(".bin", ".safetensors")
    output_weights_path = os.path.join(output_dir, weights_name)
    save_file(unified_weights, output_weights_path)

    create_tokenizer_config(output_dir, unified_config)
    logger.info(f"Conversion completed, model saved to: {output_dir}")

    if push_to_hub_repo:
        push_to_hub(
            local_dir=output_dir,
            repo_id=push_to_hub_repo,
            commit_message="Upload converted BLT model",
            private=hub_private,
            token=hub_token,
        )


def main():
    parser = argparse.ArgumentParser(
        description="Convert BLT models from HuggingFace Hub format to unified format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="facebook/blt-1b",
        help="HuggingFace Hub model id to convert",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./blt_converted",
        help="Directory where the converted model is written",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default="config.json",
        help="Filename of the merged configuration",
    )
    parser.add_argument(
        "--weights_name",
        type=str,
        default="model.bin",
        help="Filename of the merged weights (always saved as .safetensors)",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Optional cache directory for downloaded files",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug logging",
    )
    parser.add_argument(
        "--push_to_hub",
        type=str,
        default=None,
        help="Optional Hub repo id to push the converted model to",
    )
    parser.add_argument(
        "--hub_private",
        action="store_true",
        default=False,
        help="Push to a private Hub repo",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="HuggingFace Hub token (defaults to the cached login)",
    )

    args = parser.parse_args()

    if args.debug:
        transformers_logging.set_verbosity_debug()
        logging.basicConfig(level=logging.DEBUG)

    try:
        convert_hf_blt_to_unified(
            model_id=args.model_id,
            output_dir=args.output_dir,
            config_name=args.config_name,
            weights_name=args.weights_name,
            cache_dir=args.cache_dir,
            push_to_hub_repo=args.push_to_hub,
            hub_private=args.hub_private,
            hub_token=args.hub_token,
        )
    except Exception as e:
        logger.error(f"Conversion failed: {e}")
        raise


if __name__ == "__main__":
    main()
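
# Example invocation (the script filename and target repo id below are
# illustrative, not defined by this file):
#
#   python convert_blt_to_unified.py \
#       --model_id facebook/blt-1b \
#       --output_dir ./blt_converted \
#       --push_to_hub <username>/blt-1b-unified --hub_private
#
# The output directory then contains config.json, model.safetensors and
# tokenizer_config.json.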