|
""" |
|
python create_lmdb_14b_shards.py \ |
|
--data_path /mnt/localssd/wanx_14b_data \ |
|
--lmdb_path /mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb |
|
""" |
|
import argparse
import glob
import os

import lmdb
import numpy as np
import torch
from tqdm import tqdm

from utils.lmdb import store_arrays_to_lmdb, process_data_dict
|
|
def main():
    """
    Aggregate all ODE pairs inside a folder into an LMDB dataset.

    Each .pt file should contain a (key, value) pair representing a
    video's ODE trajectories.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str,
                        required=True, help="path to the folder of ODE pairs")
    parser.add_argument("--lmdb_path", type=str,
                        required=True, help="path to the output LMDB")
    parser.add_argument("--num_shards", type=int,
                        default=16, help="number of LMDB shards to write")

    args = parser.parse_args()
|
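    # data_path is expected to contain per-part subdirectories, each holding
    # per-video .pt files of ODE pairs produced upstream.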
    all_dirs = sorted(os.listdir(args.data_path))

    # Maximum LMDB map size (~1 TB); the backing file grows lazily up to this limit.
    map_size = int(1e12)
    os.makedirs(args.lmdb_path, exist_ok=True)
|
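    # One LMDB environment per shard. sync/metasync trade write speed for
    # durability; readahead=False suits the random reads typical of training,
    # and meminit=False skips zero-initializing malloc'd write buffers.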
    envs = []
    num_shards = args.num_shards
    for shard_id in range(num_shards):
        print(f"Opening shard {shard_id}")
        path = os.path.join(args.lmdb_path, f"shard_{shard_id}")
        env = lmdb.open(path,
                        map_size=map_size,
                        subdir=True,
                        readonly=False,
                        metasync=True,
                        sync=True,
                        lock=True,
                        readahead=False,
                        meminit=False)
        envs.append(env)
|
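    # Bookkeeping: counters[i] is the next write offset into shard i, and
    # seen_prompts is threaded through process_data_dict, presumably to
    # detect or skip duplicate prompts across files.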
    counters = [0] * num_shards  # number of samples already written per shard
    seen_prompts = set()
    total_samples = 0

    all_files = []
    for part_dir in all_dirs:
        all_files += sorted(glob.glob(os.path.join(args.data_path, part_dir, "*.pt")))
|
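    # Files are distributed across shards round-robin; each file contributes
    # len(data_dict["prompts"]) samples, appended at the shard's current offset.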
    for idx, file in tqdm(enumerate(all_files), total=len(all_files)):
        try:
            # map_location="cpu" so aggregation does not require a GPU.
            data_dict = torch.load(file, map_location="cpu")
            data_dict = process_data_dict(data_dict, seen_prompts)
        except Exception as e:
            print(f"Error processing {file}: {e}")
            continue

        # Skip samples whose latents do not match the expected shape.
        if data_dict["latents"].shape != (1, 21, 16, 60, 104):
            continue

        shard_id = idx % num_shards

        store_arrays_to_lmdb(envs[shard_id], data_dict, start_index=counters[shard_id])
        counters[shard_id] += len(data_dict["prompts"])
        total_samples += len(data_dict["prompts"])  # count samples actually stored
        data_shape = data_dict["latents"].shape
|
    print(f"Seen {len(seen_prompts)} unique prompts")
|
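    # Store a "{key}_shape" metadata entry in every shard so readers can
    # reconstruct the arrays: the batch dimension is replaced by the number of
    # samples actually written to that shard. data_dict and data_shape here
    # come from the last successfully processed file (keys and trailing dims).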
    for shard_id, env in enumerate(envs):
        with env.begin(write=True) as txn:
            for key in data_dict.keys():
                assert len(data_shape) == 5
                array_shape = np.array(data_shape)
                array_shape[0] = counters[shard_id]
                shape_key = f"{key}_shape".encode()
                print(shape_key, array_shape)
                shape_str = " ".join(map(str, array_shape))
                txn.put(shape_key, shape_str.encode())

    print(f"Finished writing {total_samples} examples into {num_shards} shards under {args.lmdb_path}")
|
if __name__ == "__main__":
    main()