Source code for torchnlp.samplers.sorted_sampler

from torch.utils.data.sampler import Sampler

from torchnlp.utils import identity


class SortedSampler(Sampler):
    """Samples the indices of ``data`` in ascending order of ``sort_key``.

    The ordering is computed once, in ``__init__``, so every iteration over
    the sampler yields the same sequence of indices. Python's ``sorted`` is
    stable, so elements with equal keys keep their original relative order.

    Args:
        data (iterable): Iterable data. It is traversed once, in ``__init__``.
        sort_key (callable): Specifies a function of one argument that is used
            to extract a comparison key from each element of ``data``.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key=identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Pair each position with its key so the sort reorders indices, not rows.
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        """Return an iterator over the precomputed, key-sorted indices."""
        return iter(self.sorted_indexes)

    def __len__(self):
        """Return the number of indices this sampler yields."""
        # Count the computed indices rather than calling ``len(self.data)``:
        # ``data`` is only required to be iterable, and a one-shot iterable
        # (e.g. a generator) has no ``__len__`` and is exhausted after __init__.
        return len(self.sorted_indexes)