Source code for torchnlp.encoders.label_encoder

from collections import Counter

from torchnlp.encoders.encoder import Encoder

import torch

DEFAULT_UNKNOWN_TOKEN = '<unk>'
DEFAULT_RESERVED = [DEFAULT_UNKNOWN_TOKEN]


class LabelEncoder(Encoder):
    """ Encodes a label via a dictionary.

    Args:
        sample (list of strings): Sample of data used to build the encoding dictionary.
        min_occurrences (int, optional): Minimum number of occurrences for a label to be added
            to the encoding dictionary.
        reserved_labels (list, optional): List of reserved labels inserted at the beginning of
            the dictionary.
        unknown_index (int, optional): The unknown label is used to encode unseen labels. This
            is the index that label resides at.
        **kwargs: Keyword arguments passed onto ``Encoder``.

    Example:

        >>> samples = ['label_a', 'label_b']
        >>> encoder = LabelEncoder(samples, reserved_labels=['unknown'], unknown_index=0)
        >>> encoder.encode('label_a')
        tensor(1)
        >>> encoder.decode(encoder.encode('label_a'))
        'label_a'
        >>> encoder.encode('label_c')
        tensor(0)
        >>> encoder.decode(encoder.encode('label_c'))
        'unknown'
        >>> encoder.vocab
        ['unknown', 'label_a', 'label_b']
    """

    def __init__(self,
                 sample,
                 min_occurrences=1,
                 reserved_labels=DEFAULT_RESERVED,
                 unknown_index=DEFAULT_RESERVED.index(DEFAULT_UNKNOWN_TOKEN),
                 **kwargs):
        super().__init__(**kwargs)

        if unknown_index is not None and unknown_index >= len(reserved_labels):
            raise ValueError('The `unknown_index`, if provided, must index into `reserved_labels`.')

        self.unknown_index = unknown_index
        self.tokens = Counter(sample)
        self.index_to_token = reserved_labels.copy()
        self.token_to_index = {token: index for index, token in enumerate(reserved_labels)}
        # Extend the reserved vocabulary with sample labels that occur often enough.
        for token, count in self.tokens.items():
            if count >= min_occurrences:
                self.index_to_token.append(token)
                self.token_to_index[token] = len(self.index_to_token) - 1

    @property
    def vocab(self):
        """
        Returns:
            list: List of labels in the dictionary.
        """
        return self.index_to_token

    @property
    def vocab_size(self):
        """
        Returns:
            int: Number of labels in the dictionary.
        """
        return len(self.vocab)
    def encode(self, label):
        """ Encodes a ``label``.

        Args:
            label (object): Label to encode.

        Returns:
            torch.Tensor: Encoding of the label.
        """
        label = super().encode(label)

        return torch.tensor(self.token_to_index.get(label, self.unknown_index), dtype=torch.long)
    def batch_encode(self, iterator, *args, dim=0, **kwargs):
        """
        Args:
            iterator (iterator): Batch of labels to encode.
            *args: Arguments passed to ``Encoder.batch_encode``.
            dim (int, optional): Dimension along which to concatenate tensors.
            **kwargs: Keyword arguments passed to ``Encoder.batch_encode``.

        Returns:
            torch.Tensor: Tensor of encoded labels.
        """
        return torch.stack(super().batch_encode(iterator, *args, **kwargs), dim=dim)
    def decode(self, encoded):
        """ Decodes ``encoded`` label.

        Args:
            encoded (torch.Tensor): Encoded label.

        Returns:
            object: Label decoded from ``encoded``.
        """
        encoded = super().decode(encoded)

        if encoded.numel() > 1:
            raise ValueError(
                '``decode`` decodes one label at a time, use ``batch_decode`` instead.')

        return self.index_to_token[encoded.squeeze().item()]
    def batch_decode(self, tensor, *args, dim=0, **kwargs):
        """
        Args:
            tensor (torch.Tensor): Batch of tensors.
            *args: Arguments passed to ``Encoder.batch_decode``.
            dim (int, optional): Dimension along which to split tensors.
            **kwargs: Keyword arguments passed to ``Encoder.batch_decode``.

        Returns:
            list: Batch of decoded labels.
        """
        # Split the batch along ``dim`` into single-label tensors and forward
        # ``*args``/``**kwargs`` to ``Encoder.batch_decode``, as documented above.
        return super().batch_decode([t.squeeze(0) for t in tensor.split(1, dim=dim)], *args,
                                    **kwargs)
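

# --- Usage sketches (illustrative only, not part of the library source) ---
# A minimal sketch of how ``min_occurrences`` interacts with the default reserved
# vocabulary ``['<unk>']`` (``unknown_index=0``): labels rarer than the threshold
# are left out of the vocabulary and encode to the unknown index.
if __name__ == '__main__':
    sample = ['cat', 'cat', 'dog']
    encoder = LabelEncoder(sample, min_occurrences=2)
    assert encoder.vocab == [DEFAULT_UNKNOWN_TOKEN, 'cat']  # 'dog' occurs once, so it is dropped
    assert encoder.encode('dog').item() == 0  # rare and unseen labels map to '<unk>'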
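
    # A sketch of the batch round trip: ``batch_encode`` stacks the per-label scalar
    # tensors along ``dim``, and ``batch_decode`` splits that tensor back into labels.
    encoder = LabelEncoder(['label_a', 'label_b'])
    batch = encoder.batch_encode(['label_a', 'label_b', 'label_a'])
    assert batch.tolist() == [1, 2, 1]  # index 0 is reserved for '<unk>'
    assert encoder.batch_decode(batch) == ['label_a', 'label_b', 'label_a']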