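"""Convert BLT (Byte Latent Transformer) checkpoints from the Hugging Face Hub
layout (main model weights plus a separate entropy model) into a single unified,
Transformers-style checkpoint: a merged config.json, one safetensors weight file
with HF parameter names, and a tokenizer config.

Example invocation (assuming the default repo layout of facebook/blt-1b):

    python convert_blt_weights_to_hf.py \
        --model_id facebook/blt-1b \
        --output_dir ./blt_converted
"""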
import argparse
import json
import logging
import os
from typing import Any, Dict, Optional
import torch
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from safetensors.torch import load_file, save_file
# Imported for reference and optional validation of the converted checkpoint;
# the conversion itself only manipulates config dicts and raw tensors.
from transformers.models.blt_wip.configuration_blt import BLTConfig
from transformers.models.blt_wip.modeling_blt import BLTModel
from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
from transformers.utils import logging as transformers_logging
logger = transformers_logging.get_logger(__name__)
transformers_logging.set_verbosity_info()
def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
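    """Merge the main BLT config with the entropy-model params into a single
    HF-style config dict with nested patcher/encoder/decoder/global sub-configs.
    """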
logger.info("Merging configurations")
with open(config_path, "r") as f:
main_config = json.load(f)
with open(entropy_params_path, "r") as f:
entropy_data = json.load(f)
entropy_model_params = entropy_data.get("entropy_model", {})
patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
    unified_config = dict(main_config["args"])  # copy the args dict so the type coercion below does not mutate the source config
for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
if key in unified_config and not isinstance(unified_config[key], int):
unified_config[key] = int(unified_config[key])
patch_size = patcher_args.get("patch_size", 8)
if isinstance(patch_size, float):
patch_size = int(patch_size)
# Create patcher config
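    # Llama-style SwiGLU sizing: the FFN width is 8/3 of the hidden size,
    # rounded up to the nearest multiple of `multiple_of`.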
patcher_hidden_size = int(entropy_model_params.get("dim", 512))
patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256))
patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of)
patcher_config = {
"vocab_size": int(entropy_model_params.get("vocab_size", 256)),
"hidden_size": patcher_hidden_size,
"num_hidden_layers": int(entropy_model_params.get("n_layers", 8)),
"num_attention_heads": int(entropy_model_params.get("n_heads", 8)),
"num_key_value_heads": int(entropy_model_params.get("n_kv_heads"))
if entropy_model_params.get("n_kv_heads") is not None
else None,
"max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)),
"norm_eps": entropy_model_params.get("norm_eps", 1e-5),
"dropout": entropy_model_params.get("dropout", 0.0),
"rope_theta": entropy_model_params.get("rope_theta", 10000.0),
"attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
"attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
"intermediate_size": patcher_intermediate_size,
}
# Create encoder config
encoder_hidden_size = unified_config.get("dim_local_encoder", 1024)
encoder_multiple_of = unified_config.get("multiple_of", 256)
encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of)
encoder_config = {
"vocab_size": unified_config.get("vocab_size", 256),
"cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False),
"cross_attn_k": unified_config.get("cross_attn_k", 2),
"hidden_size_global": unified_config.get("hidden_size_global", 2048),
"pm_size": unified_config.get("pm_size", 0),
"hidden_size": encoder_hidden_size,
"num_attention_heads": unified_config.get("n_heads_local_encoder", 16),
"num_key_value_heads": unified_config.get("n_kv_heads"),
"num_hidden_layers": unified_config.get("n_layers_local_encoder", 1),
"norm_eps": unified_config.get("norm_eps", 1e-5),
"dropout": unified_config.get("dropout", 0.0),
"max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
"rope_theta": unified_config.get("rope_theta", 10000.0),
"rope_scaling": {"rope_type": "default"},
"hidden_act": unified_config.get("hidden_act", "silu"),
"_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
"intermediate_size": encoder_intermediate_size,
}
# Create decoder config
decoder_hidden_size = unified_config.get("dim_local_decoder", 1024)
decoder_multiple_of = unified_config.get("multiple_of", 256)
decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of)
decoder_config = {
"vocab_size": unified_config.get("vocab_size", 256),
"cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False),
"cross_attn_k": unified_config.get("cross_attn_k", 2),
"hidden_size_global": unified_config.get("hidden_size_global", 2048),
"hidden_size": decoder_hidden_size,
"num_attention_heads": unified_config.get("n_heads_local_decoder", 16),
"num_key_value_heads": unified_config.get("n_kv_heads"),
"num_hidden_layers": unified_config.get("n_layers_local_decoder", 9),
"norm_eps": unified_config.get("norm_eps", 1e-5),
"dropout": unified_config.get("dropout", 0.0),
"max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
"rope_theta": unified_config.get("rope_theta", 10000.0),
"rope_scaling": {"rope_type": "default"},
"hidden_act": unified_config.get("hidden_act", "silu"),
"_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
"intermediate_size": decoder_intermediate_size,
}
# Create global transformer config
global_hidden_size = unified_config.get("dim_global", 2048)
global_multiple_of = unified_config.get("multiple_of", 256)
global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of)
global_config = {
"hidden_size": global_hidden_size,
"num_attention_heads": unified_config.get("n_heads_global", 16),
"num_key_value_heads": unified_config.get("n_kv_heads_global"),
"num_hidden_layers": unified_config.get("n_layers_global", 25),
"norm_eps": unified_config.get("norm_eps", 1e-5),
"dropout": unified_config.get("dropout", 0.0),
"max_position_embeddings": unified_config.get("max_seqlen", 1024),
"rope_theta": unified_config.get("rope_theta", 10000.0),
"rope_scaling": {"rope_type": "default"},
"hidden_act": unified_config.get("hidden_act", "silu"),
"_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
"intermediate_size": global_intermediate_size,
}
# Create main config with sub-configs
main_config_dict = {
"model_type": "blt",
"vocab_size": unified_config.get("vocab_size", 256),
"max_position_embeddings": unified_config.get("max_seqlen", 1024),
"patch_in_forward": True,
"realtime_patching": True,
"patching_mode": "entropy",
"patch_size": patch_size,
"patching_threshold": patcher_args.get("threshold", 0.5),
"patching_threshold_add": patcher_args.get("threshold_add", 0.0),
"max_patch_length": patcher_args.get("max_patch_length"),
"patching_batch_size": patcher_args.get("patching_batch_size", 1),
"patching_device": patcher_args.get("patching_device", "cuda"),
"monotonicity": patcher_args.get("monotonicity", False),
"cross_attn_k": unified_config.get("cross_attn_k", 2),
"encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"),
"encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000),
"encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3),
"pm_size": unified_config.get("pm_size", 0),
"patcher_config": patcher_config,
"encoder_config": encoder_config,
"decoder_config": decoder_config,
"global_config": global_config,
        "tie_word_embeddings": False,
    }
logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
return main_config_dict
def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
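    """Rename checkpoint tensors from BLT's original naming scheme to
    Transformers conventions (attention -> self_attn, w1/w2/w3 ->
    gate/down/up_proj, and so on) via substring replacement."""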
component_mappings = {
".attention.": ".self_attn.",
".feed_forward.": ".mlp.",
".attention_norm.": ".input_layernorm.",
".ffn_norm.": ".post_attention_layernorm.",
".tok_embeddings.": ".embed_tokens.",
".cross_attn_norm_q.": ".q_norm.",
".cross_attn_norm_kv.": ".k_norm.",
".w1.": ".gate_proj.",
".w2.": ".down_proj.",
".w3.": ".up_proj.",
".wq.": ".q_proj.",
".wk.": ".k_proj.",
".wv.": ".v_proj.",
".wo.": ".o_proj.",
".output.": ".lm_head.",
}
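    # Applied as substring replacement on each key, e.g.:
    #   "local_decoder.output.weight"          -> "local_decoder.lm_head.weight"
    #   "...layers.0.attention.wq.weight"      -> "...layers.0.self_attn.q_proj.weight"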
new_state_dict = {}
for old_key, tensor in state_dict.items():
new_key = old_key
for old_pattern, new_pattern in component_mappings.items():
if old_pattern in new_key:
new_key = new_key.replace(old_pattern, new_pattern)
new_state_dict[new_key] = tensor
return new_state_dict
def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
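    """Load the main safetensors weights and the entropy-model checkpoint,
    namespace the entropy model under `patcher.`, rename everything to HF
    conventions, and hoist the decoder output head to a top-level lm_head."""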
main_weights = load_file(weights_path)
entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True)
if "model" in entropy_weights:
entropy_weights = entropy_weights["model"]
elif "state_dict" in entropy_weights:
entropy_weights = entropy_weights["state_dict"]
unified_weights = main_weights.copy()
for key, tensor in entropy_weights.items():
patcher_key = f"patcher.{key}"
unified_weights[patcher_key] = tensor
unified_weights = apply_weight_mapping(unified_weights)
    # Hoist the decoder's output projection up to the model-level LM head.
    decoder_lm_head_key = "local_decoder.lm_head.weight"
    top_lm_head_key = "lm_head.weight"
    unified_weights[top_lm_head_key] = unified_weights.pop(decoder_lm_head_key)
prefixed_weights = {}
for key, tensor in unified_weights.items():
if key == top_lm_head_key:
prefixed_weights[key] = tensor
elif not key.startswith("model."):
prefixed_weights[f"model.{key}"] = tensor
else:
prefixed_weights[key] = tensor
unified_weights = prefixed_weights
return unified_weights
def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
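    """Write a minimal tokenizer_config.json for the byte-level BLT tokenizer."""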
tokenizer_config = {
"tokenizer_class": "BltTokenizer",
"vocab_size": config.get("vocab_size", 256),
"model_max_length": config.get("max_seqlen", 1024),
"add_bos_token": True,
"add_eos_token": True,
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<pad>",
"unk_token": "<unk>",
}
tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
with open(tokenizer_path, "w") as f:
json.dump(tokenizer_config, f, indent=2)
def push_to_hub(
local_dir: str,
repo_id: str,
commit_message: str = "Upload converted BLT model",
private: bool = False,
token: Optional[str] = None,
) -> None:
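    """Push the converted model folder to the Hugging Face Hub, creating the
    target repo (respecting `private`) if it does not exist yet."""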
    try:
        # Create the repo first so the `private` flag is honored;
        # upload_folder itself neither creates repos nor sets visibility.
        create_repo(repo_id, repo_type="model", private=private, exist_ok=True, token=token)
        upload_folder(
            folder_path=local_dir,
            repo_id=repo_id,
            commit_message=commit_message,
            repo_type="model",
            token=token,
        )
logger.info(f"Successfully pushed model to {repo_id}")
except Exception as e:
logger.error(f"Failed to push model to Hub: {e}")
raise
def convert_hf_blt_to_unified(
model_id: str,
output_dir: str,
config_name: str = "config.json",
weights_name: str = "model.bin",
cache_dir: Optional[str] = None,
push_to_hub_repo: Optional[str] = None,
hub_private: bool = False,
hub_token: Optional[str] = None,
) -> None:
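    """Download the BLT checkpoint files from the Hub, merge the main model and
    the entropy model into one config and state dict, save them in HF format,
    and optionally push the result back to the Hub."""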
# Download model files
config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir)
entropy_weights_path = hf_hub_download(
repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir
)
unified_config = merge_configurations(config_path, entropy_params_path)
unified_weights = merge_weights(weights_path, entropy_weights_path)
os.makedirs(output_dir, exist_ok=True)
    output_config_path = os.path.join(output_dir, config_name)
    with open(output_config_path, "w") as f:
        json.dump(unified_config, f, indent=2)
    # Always save as safetensors, even if a .bin name was requested.
    if weights_name.endswith(".bin"):
        weights_name = weights_name.replace(".bin", ".safetensors")
    output_weights_path = os.path.join(output_dir, weights_name)
    save_file(unified_weights, output_weights_path)
create_tokenizer_config(output_dir, unified_config)
logger.info(f"Conversion completed, model saved to: {output_dir}")
if push_to_hub_repo:
push_to_hub(
local_dir=output_dir,
repo_id=push_to_hub_repo,
commit_message="Upload BLT model converted",
private=hub_private,
token=hub_token,
)
def main():
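    """CLI entry point: parse arguments, configure logging, run the conversion."""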
parser = argparse.ArgumentParser(
description="Convert BLT models from HuggingFace Hub format to unified format",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
    parser.add_argument(
        "--model_id",
        type=str,
        default="facebook/blt-1b",
        help="Hub repo id of the source BLT model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./blt_converted",
        help="Directory to write the converted model to.",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default="config.json",
        help="Filename for the merged config in the output directory.",
    )
    parser.add_argument(
        "--weights_name",
        type=str,
        default="model.bin",
        help="Filename for the merged weights (a .bin name is rewritten to .safetensors).",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Optional cache directory for Hub downloads.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug-level logging.",
    )
    parser.add_argument(
        "--push_to_hub",
        type=str,
        default=None,
        help="Optional Hub repo id to push the converted model to.",
    )
    parser.add_argument(
        "--hub_private",
        action="store_true",
        help="Create the Hub repo as private when pushing.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="Hugging Face token used for pushing (defaults to the cached login).",
    )
    args = parser.parse_args()
    if args.debug:
        transformers_logging.set_verbosity_debug()
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)
try:
convert_hf_blt_to_unified(
model_id=args.model_id,
output_dir=args.output_dir,
config_name=args.config_name,
weights_name=args.weights_name,
cache_dir=args.cache_dir,
push_to_hub_repo=args.push_to_hub,
hub_private=args.hub_private,
hub_token=args.hub_token,
)
except Exception as e:
logger.error(f"Conversion failed: {e}")
raise
if __name__ == "__main__":
main()