Source code for torchnlp.samplers.bptt_batch_sampler

import math

from torch.utils.data.sampler import Sampler

from torchnlp.samplers.bptt_sampler import BPTTSampler


class BPTTBatchSampler(Sampler):
    """ Samples sequentially a batch of source and target slices of size ``bptt_length``.

    Typically, such a sampler is used for language modeling training with backpropagation through
    time (BPTT).

    **Reference:**
    https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L61

    Args:
        data (iterable): Data to sample from.
        bptt_length (int): Length of the slice.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be
            less than ``batch_size``.
        type_ (str, optional): Type of batch ['source'|'target'] to load where a target batch is
            one timestep ahead.

    Example:
        >>> sampler = BPTTBatchSampler(range(100), bptt_length=2, batch_size=3, drop_last=False)
        >>> list(sampler)[0]  # First batch
        [slice(0, 2, None), slice(34, 36, None), slice(67, 69, None)]
    """

    def __init__(self, data, bptt_length, batch_size, drop_last, type_='source'):
        self.data = data
        self.batch_size = batch_size
        self.drop_last = drop_last

        # For each row in the batch, we iterate over a chunk of size `chunk_size`.
        # Our chunks are similar to the columns in this PyTorch example:
        # https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L61
        chunk_sizes = [math.floor(len(data) / batch_size)] * batch_size

        # Distribute the remaining elements across the first chunks.
        if not self.drop_last:
            remainder = len(data) - sum(chunk_sizes)
            for i in range(remainder):
                chunk_sizes[i] += 1

        self.samplers = [{
            'offset': sum(chunk_sizes[:i]),
            'sampler': BPTTSampler(range(chunk_sizes[i]), bptt_length, type_=type_)
        } for i in range(batch_size)]

    def __iter__(self):
        # Samplers iterate over chunks similar to:
        # https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L112
        self.iterators = [iter(value['sampler']) for value in self.samplers]
        while True:
            batch = []
            for i, iterator in enumerate(self.iterators):
                try:
                    # Shift the sampler's slice by this chunk's offset into `data`.
                    offset = self.samplers[i]['offset']
                    slice_ = next(iterator)
                    batch.append(slice(slice_.start + offset, slice_.stop + offset))
                except StopIteration:
                    pass

            # All samplers are exhausted.
            if len(batch) == 0:
                break

            yield batch

    def __len__(self):
        # The first chunk is the largest, so its sampler yields the most slices and
        # determines the number of batches.
        return len(self.samplers[0]['sampler'])
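
The following standalone sketch (not part of the module) shows one way to turn the slices produced by this sampler into aligned source/target mini-batches for language-model training; the token tensor ``tokens`` and the stacking step are illustrative assumptions, not torchnlp API.

import torch

from torchnlp.samplers.bptt_batch_sampler import BPTTBatchSampler

# Stand-in for an encoded corpus: a flat sequence of 100 token ids.
tokens = torch.arange(100)

# Two samplers over the same data: targets are the same slices shifted one timestep ahead.
source_sampler = BPTTBatchSampler(
    tokens, bptt_length=2, batch_size=3, drop_last=False, type_='source')
target_sampler = BPTTBatchSampler(
    tokens, bptt_length=2, batch_size=3, drop_last=False, type_='target')

source_batch = next(iter(source_sampler))  # [slice(0, 2, None), slice(34, 36, None), slice(67, 69, None)]
target_batch = next(iter(target_sampler))  # [slice(1, 3, None), slice(35, 37, None), slice(68, 70, None)]

# Gather the slices into (batch_size, bptt_length) tensors for a next-token prediction loss.
source = torch.stack([tokens[s] for s in source_batch])
target = torch.stack([tokens[s] for s in target_batch])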