Source code for torchnlp.samplers.balanced_sampler

from torchnlp._third_party.weighted_random_sampler import WeightedRandomSampler

from torchnlp.utils import identity


class BalancedSampler(WeightedRandomSampler):
    """ Weighted sampler with respect for an element's class.

    Every item's weight is divided by the summed weight of all items sharing its class, so
    the weights inside a class whose items all have positive weight add up to one no matter
    how many items that class contains. Items whose weight is not positive are assigned a
    sampling weight of ``0.0``.

    Args:
        data_source (iterable): Items to sample from. It is traversed exactly once, so a
            one-shot iterable such as a generator is accepted.
        get_class (callable, optional): Get the class of an item relative to the entire
            dataset. The returned value is used as a dictionary key and must be hashable.
        get_weight (callable, optional): Define a weight for each item other than one.
        kwargs: Additional key word arguments passed onto `WeightedRandomSampler`.

    Example:
        >>> from torchnlp.samplers import DeterministicSampler
        >>>
        >>> data = ['a', 'b', 'c'] + ['c'] * 100
        >>> sampler = BalancedSampler(data, num_samples=3)
        >>> sampler = DeterministicSampler(sampler, random_seed=12)
        >>> [data[i] for i in sampler]
        ['c', 'b', 'a']
    """

    def __init__(self, data_source, get_class=identity, get_weight=lambda x: 1, **kwargs):
        classified = []
        weighted = []
        class_totals = {}

        # Single pass: `data_source` may be a one-shot iterable, and accumulating the
        # per-class totals here avoids rescanning the whole dataset once per class.
        for item in data_source:
            item_class = get_class(item)
            item_weight = float(get_weight(item))
            classified.append(item_class)
            weighted.append(item_weight)
            class_totals[item_class] = class_totals.get(item_class, 0.0) + item_weight

        # Normalise by the class total; the `w > 0` guard maps non-positive weights to 0.0
        # instead of dividing them.
        weights = [w / class_totals[c] if w > 0 else 0.0 for c, w in zip(classified, weighted)]
        super().__init__(weights=weights, **kwargs)