# Source code for torchnlp.samplers.bptt_sampler

import math

from torch.utils.data.sampler import Sampler


class BPTTSampler(Sampler):
    """ Samples sequentially source and target slices of size ``bptt_length``.

    Typically, such a sampler, is used for language modeling training with
    backpropagation through time (BPTT).

    **Reference:**
    https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L122

    Args:
        data (iterable): Iterable data.
        bptt_length (int): Length of the slice.
        type_ (str, optional): Type of slice ['source'|'target'] to load where a target slice is
            one timestep ahead.

    Example:
        >>> from torchnlp.samplers import BPTTSampler
        >>> list(BPTTSampler(range(5), 2))
        [slice(0, 2, None), slice(2, 4, None)]
    """

    def __init__(self, data, bptt_length, type_='source'):
        self.data = data
        self.bptt_length = bptt_length
        self.type = type_

    def __iter__(self):
        # The last element of ``data`` is never the start of a source slice: it
        # only serves as the final target token, hence ``len(self.data) - 1``.
        for i in range(0, len(self.data) - 1, self.bptt_length):
            # The final window may be shorter than ``bptt_length``.
            seq_length = min(self.bptt_length, len(self.data) - 1 - i)
            if self.type == 'source':
                yield slice(i, i + seq_length)
            elif self.type == 'target':
                # Target is the source shifted one timestep ahead.
                yield slice(i + 1, i + 1 + seq_length)

    def __len__(self):
        # Number of windows covering the first ``len(data) - 1`` elements,
        # counting the trailing partial window.
        return math.ceil((len(self.data) - 1) / self.bptt_length)