Source code for torchnlp.samplers.oom_batch_sampler

import heapq

from torch.utils.data.sampler import BatchSampler

from torchnlp.utils import get_tensors


def get_number_of_elements(object_):
    """ Get the sum of the number of elements in all tensors stored in `object_`.

    This is particularly useful for sampling the largest objects based on tensor size, as in
    `OomBatchSampler.__init__.get_item_size`.

    Args:
        object_ (any)

    Returns:
        (int): The number of elements in `object_`.
    """
    return sum([t.numel() for t in get_tensors(object_)])
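
A minimal usage sketch, not part of the module: assuming `torchnlp.utils.get_tensors` returns a bare tensor as-is and collects every tensor nested inside dicts, lists, and tuples, `get_number_of_elements` sums their element counts:

    # Hedged example; `batch` is a hypothetical collated item, not from the original source.
    import torch

    batch = {'tokens': torch.zeros(8, 100), 'labels': torch.zeros(8)}
    assert get_number_of_elements(batch) == 8 * 100 + 8  # 808 elements in total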
class OomBatchSampler(BatchSampler):
    """ Out-of-memory (OOM) batch sampler.

    Wraps `batch_sampler` to sample the `num_batches` largest batches first, in an attempt to
    trigger any potential OOM errors early.

    Credits:
    https://github.com/allenai/allennlp/blob/3d100d31cc8d87efcf95c0b8d162bfce55c64926/allennlp/data/iterators/bucket_iterator.py#L43

    Args:
        batch_sampler (torch.utils.data.sampler.BatchSampler)
        get_item_size (callable): Measures the size of an item given its index (`int`).
        num_batches (int, optional): Number of large batches to move to the beginning of the
            iteration.
    """

    def __init__(self, batch_sampler, get_item_size, num_batches=5):
        self.batch_sampler = batch_sampler
        self.get_item_size = get_item_size
        self.num_batches = num_batches

    def __iter__(self):
        batches = list(iter(self.batch_sampler))
        # Indices of the `num_batches` batches with the largest total item size.
        largest_batches = heapq.nlargest(
            self.num_batches,
            range(len(batches)),
            key=lambda i: sum([self.get_item_size(j) for j in batches[i]]))
        move_to_front = [batches[i] for i in largest_batches]
        # Pop from the back first so earlier indices remain valid, then prepend the
        # largest batches.
        for i in sorted(largest_batches, reverse=True):
            batches.pop(i)
        batches[0:0] = move_to_front
        return iter(batches)

    def __len__(self):
        return len(self.batch_sampler)
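
A minimal end-to-end sketch, assuming a hypothetical in-memory `dataset` of variable-length tensors (not from the original source). With `num_batches=1`, the single largest batch is moved to the front of the iteration order:

    import torch
    from torch.utils.data.sampler import BatchSampler, SequentialSampler

    # Hypothetical toy dataset; element i holds n scalar elements.
    dataset = [torch.zeros(n) for n in (5, 50, 500, 3, 30)]
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
    oom_sampler = OomBatchSampler(
        batch_sampler,
        get_item_size=lambda i: get_number_of_elements(dataset[i]),
        num_batches=1)
    # Batch [2, 3] has 500 + 3 elements, the most of any batch, so it is drawn first.
    assert list(oom_sampler) == [[2, 3], [0, 1], [4]]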