from torch.utils.data.sampler import BatchSampler
from torchnlp.samplers.distributed_sampler import DistributedSampler


class DistributedBatchSampler(BatchSampler):
    """ `BatchSampler` wrapper that distributes each batch across multiple workers.

    Args:
        batch_sampler (torch.utils.data.sampler.BatchSampler): Batch sampler to wrap and split
            across distributed workers.
        num_replicas (int, optional): Number of processes participating in distributed training.
        rank (int, optional): Rank of the current process within num_replicas.

    Example:
        >>> from torch.utils.data.sampler import BatchSampler
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(12)))
        >>> batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
        >>>
        >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0))
        [[0, 2], [4, 6], [8, 10]]
        >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=1))
        [[1, 3], [5, 7], [9, 11]]
    """

    def __init__(self, batch_sampler, **kwargs):
        self.batch_sampler = batch_sampler
        self.kwargs = kwargs

    def __iter__(self):
        for batch in self.batch_sampler:
            # Split each batch so that this worker only keeps its own share
            # of the batch's indices.
            yield list(DistributedSampler(batch, **self.kwargs))

    def __len__(self):
        return len(self.batch_sampler)
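

# A minimal usage sketch (an illustration, not part of the library): the wrapped
# batch sampler plugs into `torch.utils.data.DataLoader` via its `batch_sampler`
# argument, so each process only loads its own slice of every batch. `rank`
# would normally come from `torch.distributed.get_rank()`; it is hard-coded
# here, along with a toy `dataset`, to keep the example runnable.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SequentialSampler

    dataset = list(range(12))
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

    distributed_batch_sampler = DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0)
    loader = DataLoader(dataset, batch_sampler=distributed_batch_sampler)
    print([batch.tolist() for batch in loader])  # [[0, 2], [4, 6], [8, 10]]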