from typing import Any

import copy

import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset

from src.data.dataset.randn import RandomNDataset
def micro_batch_collate_fn(batch):
    # Each element of `batch` is itself a micro-batch: a list of
    # (x, y, metadata) samples. Flatten the micro-batches into one batch,
    # stack the inputs, and collate the metadata key by key.
    batch = copy.deepcopy(batch)
    new_batch = []
    for micro_batch in batch:
        new_batch.extend(micro_batch)
    x, y, metadata = list(zip(*new_batch))
    stacked_metadata = {}
    for key in metadata[0].keys():
        try:
            if isinstance(metadata[0][key], torch.Tensor):
                stacked_metadata[key] = torch.stack([m[key] for m in metadata], dim=0)
            else:
                stacked_metadata[key] = [m[key] for m in metadata]
        except Exception:
            # Skip metadata keys that cannot be collated (e.g. ragged tensors).
            pass
    x = torch.stack(x, dim=0)
    return x, y, stacked_metadata
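
# Illustrative example (shapes and metadata keys are arbitrary assumptions):
# micro_batch_collate_fn expects a list of micro-batches, each a list of
# (x, y, metadata) samples, and flattens them into a single batch.
def _example_micro_batch_collate():
    def make_sample():
        return torch.randn(3, 32, 32), 0, {"index": torch.tensor(0), "path": "a.png"}

    # 2 micro-batches of 2 samples collate into a single batch of 4.
    batch = [[make_sample(), make_sample()], [make_sample(), make_sample()]]
    x, y, metadata = micro_batch_collate_fn(batch)
    assert x.shape == (4, 3, 32, 32)
    assert metadata["index"].shape == (4,)
    assert metadata["path"] == ["a.png"] * 4
    return x, y, metadata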
def collate_fn(batch):
    # Standard training collate: stack the inputs and collate the metadata key
    # by key, stacking tensor-valued entries and listing everything else.
    batch = copy.deepcopy(batch)
    x, y, metadata = list(zip(*batch))
    stacked_metadata = {}
    for key in metadata[0].keys():
        try:
            if isinstance(metadata[0][key], torch.Tensor):
                stacked_metadata[key] = torch.stack([m[key] for m in metadata], dim=0)
            else:
                stacked_metadata[key] = [m[key] for m in metadata]
        except Exception:
            # Skip metadata keys that cannot be collated (e.g. ragged tensors).
            pass
    x = torch.stack(x, dim=0)
    return x, y, stacked_metadata
def eval_collate_fn(batch):
    # Evaluation collate: stack the inputs but keep metadata as a per-sample
    # tuple of dicts, which is convenient for per-sample bookkeeping.
    batch = copy.deepcopy(batch)
    x, y, metadata = list(zip(*batch))
    x = torch.stack(x, dim=0)
    return x, y, metadata
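
# Illustrative example (shapes and keys are arbitrary assumptions): unlike
# collate_fn, eval_collate_fn leaves metadata untouched, one dict per sample.
def _example_eval_collate():
    batch = [
        (torch.randn(3, 32, 32), 1, {"path": "a.png"}),
        (torch.randn(3, 32, 32), 0, {"path": "b.png"}),
    ]
    x, y, metadata = eval_collate_fn(batch)
    assert x.shape == (2, 3, 32, 32)
    assert y == (1, 0)
    assert metadata == ({"path": "a.png"}, {"path": "b.png"})
    return x, y, metadata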
class DataModule(pl.LightningDataModule):
    def __init__(self,
                 train_dataset: Dataset = None,
                 eval_dataset: Dataset = None,
                 pred_dataset: Dataset = None,
                 train_batch_size=64,
                 train_num_workers=16,
                 train_prefetch_factor=8,
                 eval_batch_size=32,
                 eval_num_workers=4,
                 pred_batch_size=32,
                 pred_num_workers=4,
                 ):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.pred_dataset = pred_dataset
        # stupid data_convert override, just to make nebular happy
        self.train_batch_size = train_batch_size
        self.train_num_workers = train_num_workers
        self.train_prefetch_factor = train_prefetch_factor
        self.eval_batch_size = eval_batch_size
        self.eval_num_workers = eval_num_workers
        self.pred_batch_size = pred_batch_size
        self.pred_num_workers = pred_num_workers
        self._train_dataloader = None
    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # No-op hook: batches are handed to the device transfer unchanged.
        return batch
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        # Datasets that pre-group samples expose `micro_batch_size`; in that
        # case the DataLoader draws train_batch_size // micro_batch_size items
        # per step and the collate function flattens them back to a full batch.
        micro_batch_size = getattr(self.train_dataset, "micro_batch_size", None)
        if micro_batch_size is not None:
            assert self.train_batch_size % micro_batch_size == 0, \
                "train_batch_size must be divisible by micro_batch_size"
            dataloader_batch_size = self.train_batch_size // micro_batch_size
            train_collate_fn = micro_batch_collate_fn
        else:
            dataloader_batch_size = self.train_batch_size
            train_collate_fn = collate_fn
        # Map-style datasets are sharded across ranks with a DistributedSampler;
        # iterable datasets are expected to handle sharding themselves.
        if not isinstance(self.train_dataset, IterableDataset):
            sampler = DistributedSampler(self.train_dataset)
        else:
            sampler = None
        self._train_dataloader = DataLoader(
            self.train_dataset,
            dataloader_batch_size,
            timeout=6000,
            num_workers=self.train_num_workers,
            prefetch_factor=self.train_prefetch_factor,
            collate_fn=train_collate_fn,
            sampler=sampler,
        )
        return self._train_dataloader
    def val_dataloader(self) -> EVAL_DATALOADERS:
        # Shard the eval set across ranks without shuffling so that every rank
        # sees a fixed, disjoint slice of the data.
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
        return DataLoader(self.eval_dataset, self.eval_batch_size,
                          num_workers=self.eval_num_workers,
                          prefetch_factor=2,
                          sampler=sampler,
                          collate_fn=eval_collate_fn,
                          )
    def predict_dataloader(self) -> EVAL_DATALOADERS:
        # Same sharding scheme as val_dataloader, applied to the prediction set.
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
        return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size,
                          num_workers=self.pred_num_workers,
                          prefetch_factor=4,
                          sampler=sampler,
                          collate_fn=eval_collate_fn,
                          )
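
# Minimal usage sketch with a hypothetical toy dataset; the project's real
# datasets (e.g. RandomNDataset above) take their own constructor arguments.
# The train/val/predict dataloaders build DistributedSampler instances, so the
# module is meant to be driven by a distributed Lightning Trainer.
if __name__ == "__main__":
    class _ToyDataset(Dataset):
        def __len__(self):
            return 256

        def __getitem__(self, idx):
            return torch.randn(3, 32, 32), idx % 2, {"index": torch.tensor(idx)}

    datamodule = DataModule(
        train_dataset=_ToyDataset(),
        eval_dataset=_ToyDataset(),
        pred_dataset=_ToyDataset(),
        train_batch_size=8,
        train_num_workers=2,
        train_prefetch_factor=2,
        eval_batch_size=4,
        eval_num_workers=2,
    )
    # Typical use with a Lightning trainer (model is a LightningModule):
    #   trainer = pl.Trainer(devices=..., strategy="ddp")
    #   trainer.fit(model, datamodule=datamodule)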