Source code for torchnlp.samplers.distributed_batch_sampler

from torch.utils.data.sampler import BatchSampler

from torchnlp.samplers.distributed_sampler import DistributedSampler


class DistributedBatchSampler(BatchSampler):
    """ `BatchSampler` wrapper that distributes each batch across multiple workers.

    Args:
        batch_sampler (torch.utils.data.sampler.BatchSampler)
        num_replicas (int, optional): Number of processes participating in distributed training.
        rank (int, optional): Rank of the current process within ``num_replicas``.

    Example:
        >>> from torch.utils.data.sampler import BatchSampler
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(12)))
        >>> batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
        >>>
        >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0))
        [[0, 2], [4, 6], [8, 10]]
        >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=1))
        [[1, 3], [5, 7], [9, 11]]
    """

    def __init__(self, batch_sampler, **kwargs):
        self.batch_sampler = batch_sampler
        self.kwargs = kwargs

    def __iter__(self):
        # For each batch from the wrapped sampler, keep only the indices
        # assigned to this replica via `DistributedSampler`.
        for batch in self.batch_sampler:
            yield list(DistributedSampler(batch, **self.kwargs))

    def __len__(self):
        return len(self.batch_sampler)
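
A minimal usage sketch (not part of this module) showing how the sampler can be wired into a ``torch.utils.data.DataLoader`` through its ``batch_sampler`` argument; the toy list dataset, batch size, and replica count below are illustrative assumptions.

    # Sketch: feed DistributedBatchSampler into a DataLoader (assumed setup).
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import BatchSampler, SequentialSampler

    dataset = list(range(12))  # toy dataset of integer "examples"
    base = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

    # Each replica constructs its own sampler with its rank.
    rank0_sampler = DistributedBatchSampler(base, num_replicas=2, rank=0)
    loader = DataLoader(dataset, batch_sampler=rank0_sampler)

    for batch in loader:
        print(batch)  # rank 0 sees the even positions of each batch: [0, 2], [4, 6], [8, 10]

Because each replica skips the indices belonging to the other replicas, every element of a batch is processed by exactly one worker, while the number of batches per epoch stays the same on every rank.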