import os.path import random import re import unicodedata import torch from torch.utils.data import Dataset from PIL import Image from typing import List, Union def clean_filename(s): # 去除首尾空格和点号 s = s.strip().strip('.') # 转换 Unicode 字符为 ASCII 形式 s = unicodedata.normalize('NFKD', s).encode('ASCII', 'ignore').decode('ASCII') illegal_chars = r'[/]' reserved_names = set() # 替换非法字符为下划线 s = re.sub(illegal_chars, '_', s) # 合并连续的下划线 s = re.sub(r'_{2,}', '_', s) # 转换为小写 s = s.lower() # 检查是否为保留文件名 if s.upper() in reserved_names: s = s + '_' # 限制文件名长度 max_length = 200 s = s[:max_length] if not s: return 'untitled' return s def save_fn(image, metadata, root_path): image_path = os.path.join(root_path, str(metadata['filename'])+".png") Image.fromarray(image).save(image_path) class RandomNDataset(Dataset): def __init__(self, latent_shape=(4, 64, 64), conditions:Union[int, List, str]=None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1): if isinstance(conditions, int): conditions = list(range(conditions)) # class labels elif isinstance(conditions, str): if os.path.exists(conditions): conditions = open(conditions, "r").read().splitlines() else: raise FileNotFoundError(conditions) elif isinstance(conditions, list): conditions = conditions self.conditions = conditions self.num_conditons = len(conditions) self.seeds = seeds if num_samples_per_instance > 0: max_num_instances = num_samples_per_instance*self.num_conditons else: max_num_instances = max_num_instances if seeds is not None: self.max_num_instances = len(seeds)*self.num_conditons self.num_seeds = len(seeds) else: self.num_seeds = (max_num_instances + self.num_conditons - 1) // self.num_conditons self.max_num_instances = self.num_seeds*self.num_conditons self.latent_shape = latent_shape def __getitem__(self, idx): condition = self.conditions[idx//self.num_seeds] seed = random.randint(0, 1<<31) #idx % self.num_seeds if self.seeds is not None: seed = self.seeds[idx % self.num_seeds] filename = f"{clean_filename(str(condition))}_{seed}" generator = torch.Generator().manual_seed(seed) latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) metadata = dict( filename=filename, seed=seed, condition=condition, save_fn=save_fn, ) return latent, condition, metadata def __len__(self): return self.max_num_instances class ClassLabelRandomNDataset(RandomNDataset): def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, conditions:Union[int, List, str]=None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1): if conditions is None: conditions = list(range(num_classes)) super().__init__(latent_shape, conditions, seeds, max_num_instances, num_samples_per_instance)