Source code for torchnlp.encoders.text.text_encoder

import typing

import torch

from torchnlp.encoders import Encoder
from torchnlp.encoders.text.default_reserved_tokens import DEFAULT_PADDING_INDEX


def pad_tensor(tensor, length, padding_index=DEFAULT_PADDING_INDEX):
    """ Pad a ``tensor`` to ``length`` with ``padding_index``.

    Args:
        tensor (torch.Tensor [n, ...]): Tensor to pad.
        length (int): Pad the ``tensor`` up to ``length``.
        padding_index (int, optional): Index to pad tensor with.

    Returns:
        (torch.Tensor [length, ...]) Padded Tensor.
    """
    n_padding = length - tensor.shape[0]
    assert n_padding >= 0
    if n_padding == 0:
        return tensor
    padding = tensor.new(n_padding, *tensor.shape[1:]).fill_(padding_index)
    return torch.cat((tensor, padding), dim=0)
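A minimal usage sketch (not part of the module source): padding a 1-D tensor of token indices up to a fixed length. It assumes ``DEFAULT_PADDING_INDEX`` is ``0``, the library default.

>>> import torch
>>> from torchnlp.encoders.text.text_encoder import pad_tensor
>>> pad_tensor(torch.tensor([1, 2, 3]), 5)
tensor([1, 2, 3, 0, 0])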
class SequenceBatch(typing.NamedTuple):
    tensor: torch.Tensor
    lengths: torch.Tensor
def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
    """ Pad a :class:`list` of ``tensors`` (``batch``) with ``padding_index``.

    Args:
        batch (:class:`list` of :class:`torch.Tensor`): Batch of tensors to pad.
        padding_index (int, optional): Index to pad tensors with.
        dim (int, optional): Dimension along which to stack the batch of tensors.

    Returns:
        SequenceBatch: Padded tensors and original lengths of tensors.
    """
    lengths = [tensor.shape[0] for tensor in batch]
    max_len = max(lengths)
    padded = [pad_tensor(tensor, max_len, padding_index) for tensor in batch]
    lengths = torch.tensor(lengths, dtype=torch.long)
    padded = torch.stack(padded, dim=dim).contiguous()
    for _ in range(dim):
        lengths = lengths.unsqueeze(0)
    return SequenceBatch(padded, lengths)
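A minimal usage sketch (not part of the module source): stacking two variable-length sequences into one padded batch and reading both fields of the returned ``SequenceBatch``. It assumes the default padding index of ``0``.

>>> import torch
>>> from torchnlp.encoders.text.text_encoder import stack_and_pad_tensors
>>> batch = stack_and_pad_tensors([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
>>> batch.tensor
tensor([[1, 2, 3],
        [4, 5, 0]])
>>> batch.lengths
tensor([3, 2])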
class TextEncoder(Encoder):
    def decode(self, encoded):
        """ Decodes an object.

        Args:
            encoded (object): Encoded object.

        Returns:
            object: Decoded object.
        """
        if self.enforce_reversible:
            self.enforce_reversible = False
            decoded_encoded = self.encode(self.decode(encoded))
            self.enforce_reversible = True
            if not torch.equal(decoded_encoded, encoded):
                raise ValueError('Decoding is not reversible for "%s"' % encoded)

        return encoded
    def batch_encode(self, iterator, *args, dim=0, **kwargs):
        """
        Args:
            iterator (iterator): Batch of text to encode.
            *args: Arguments passed to ``encode``.
            dim (int, optional): Dimension along which to concatenate tensors.
            **kwargs: Keyword arguments passed to ``encode``.

        Returns:
            SequenceBatch: Encoded and padded batch of sequences with the original lengths of the
                sequences.
        """
        return stack_and_pad_tensors(
            super().batch_encode(iterator, *args, **kwargs),
            padding_index=self.padding_index,
            dim=dim)
    def batch_decode(self, tensor, lengths, dim=0, *args, **kwargs):
        """
        Args:
            tensor (torch.Tensor): Batch of encoded sequences.
            lengths (torch.Tensor): Original lengths of sequences.
            dim (int, optional): Dimension along which to split tensors.
            *args: Arguments passed to ``decode``.
            **kwargs: Keyword arguments passed to ``decode``.

        Returns:
            list: Batch of decoded sequences.
        """
        return super().batch_decode(
            [t.squeeze(dim)[:l] for t, l in zip(tensor.split(1, dim=dim), lengths)],
            *args, **kwargs)
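A hedged end-to-end sketch (not part of the module source): ``TextEncoder`` itself defines no ``encode``, so the round trip below assumes a concrete subclass such as ``WhitespaceEncoder`` from ``torchnlp.encoders.text`` with its default settings (no SOS/EOS tokens appended); the exact token indices depend on the sample vocabulary, so only shapes and lengths are shown.

>>> from torchnlp.encoders.text import WhitespaceEncoder
>>> encoder = WhitespaceEncoder(['hello world', 'hi there'])
>>> batch = encoder.batch_encode(['hello world', 'hi'])
>>> batch.tensor.shape
torch.Size([2, 2])
>>> batch.lengths
tensor([2, 1])
>>> encoder.batch_decode(batch.tensor, batch.lengths)
['hello world', 'hi']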