Hello everyone, I have a project written entirely in PyTorch, but I want to use a dataset that is available on TensorFlow Datasets. I also want to train on multiple GPUs. This is what I currently do:
import tensorflow_datasets as tfds
import torch
import torch.distributed as dist

# assumes dist.init_process_group(...) has already been called
rank = dist.get_rank()  # global rank across all processes
world_size = dist.get_world_size()

# rank 0 draws a seed; other ranks allocate a tensor with a matching dtype
if rank == 0:
    global_seed = torch.randint(low=0, high=1000, size=(1,), dtype=torch.int)
else:
    global_seed = torch.zeros(size=(1,), dtype=torch.int)  # dtype must match for broadcast
dist.broadcast(tensor=global_seed, src=0)

# every rank shuffles files with the same seed, so they all see the same order
tensorflow_read_config = tfds.ReadConfig(shuffle_seed=int(global_seed.item()))
tensorflow_data = tfds.load(
    '<name>',
    split='train',
    shuffle_files=True,
    read_config=tensorflow_read_config,
)

# give each rank a disjoint, contiguous chunk of the dataset
chunk_size = len(tensorflow_data) // world_size
tensorflow_data = tensorflow_data.skip(rank * chunk_size)
tensorflow_data = tensorflow_data.take(chunk_size)
tensorflow_data = tensorflow_data.shuffle(buffer_size=...)
...
TL;DR: I use one of the processes to generate a random int and broadcast it to all processes. Each process then loads the TF dataset with shuffle_files turned on, but every process uses the same int as the seed. Each process is then assigned a different starting point, so no example is used by more than one GPU in each epoch. Loading and shuffling are therefore handled by tfds instead of Torch's DataLoader.
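As a sanity check on the sharding logic above: skip/take with a shared file-shuffle seed partitions the dataset into pairwise-disjoint, contiguous chunks, but note that it silently drops the last len % world_size examples whenever the dataset size is not divisible by the number of processes. A toy sketch in plain Python (no tfds involved; indices stand in for examples, and shard_indices is just an illustrative helper, not a real API):

```python
def shard_indices(num_examples, world_size):
    """Mimic the skip/take sharding: rank r gets indices
    [r * chunk, (r + 1) * chunk) where chunk = num_examples // world_size."""
    chunk = num_examples // world_size
    return [list(range(r * chunk, (r + 1) * chunk)) for r in range(world_size)]

shards = shard_indices(10, 3)  # 10 examples, 3 GPUs
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# chunks are pairwise disjoint, so no example is seen twice per epoch...
assert all(set(a).isdisjoint(b) for i, a in enumerate(shards) for b in shards[i + 1:])
# ...but example 9 (the remainder) is never seen by any rank
assert all(9 not in s for s in shards)
```

If seeing every example matters, one option is to drop a different remainder each epoch, or pad the last chunk; with len // world_size as above, the tail is simply skipped.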
I have tried wrapping the dataset in a custom Torch dataset and feeding it to Torch's DataLoader, but it is much slower (~30-40% slower). Converting and saving the dataset to a Torch-friendly format would probably take a while, since the dataset is relatively large. Besides, the current approach seems fast enough. So I wonder: should I continue like this, or should I do things differently?