Source code for torchnlp.nn.weight_drop

import torch

from torch.nn import Parameter


def _weight_drop(module, weights, dropout):
    """
    Helper for `WeightDrop`.
    """

    for name_w in weights:
        # Remove the original parameter and keep it under a `_raw` alias; the
        # dropped-out version is regenerated from it on each forward pass.
        w = getattr(module, name_w)
        del module._parameters[name_w]
        module.register_parameter(name_w + '_raw', Parameter(w))

    original_module_forward = module.forward

    def forward(*args, **kwargs):
        # Materialize the dropped-out weights, then delegate to the module's
        # original forward.
        for name_w in weights:
            raw_w = getattr(module, name_w + '_raw')
            w = torch.nn.functional.dropout(raw_w, p=dropout, training=module.training)
            setattr(module, name_w, w)

        return original_module_forward(*args, **kwargs)

    setattr(module, 'forward', forward)
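
# A minimal sketch of the effect (an assumed example, not part of the library):
# `_weight_drop` renames the parameter to `<name>_raw` and wraps `forward`, so
# the dropped-out weight only exists transiently during each call.
#
#   cell = torch.nn.GRUCell(2, 2)
#   _weight_drop(cell, ['weight_hh'], dropout=0.5)
#   'weight_hh_raw' in dict(cell.named_parameters())  # True
#   'weight_hh' in dict(cell.named_parameters())      # False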


class WeightDrop(torch.nn.Module):
    """
    The weight-dropped module applies recurrent regularization through a DropConnect mask on the
    hidden-to-hidden recurrent weights.

    **Thank you** to Salesforce for their initial implementation of :class:`WeightDrop`. Here is
    their `License <https://github.com/salesforce/awd-lstm-lm/blob/master/LICENSE>`__.

    Args:
        module (:class:`torch.nn.Module`): Containing module.
        weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply
            dropout to.
        dropout (float): The probability a weight will be dropped.

    Example:

        >>> from torchnlp.nn import WeightDrop
        >>> import torch
        >>>
        >>> torch.manual_seed(123)
        <torch._C.Generator object ...
        >>>
        >>> gru = torch.nn.GRUCell(2, 2)
        >>> weights = ['weight_hh']
        >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9)
        >>>
        >>> input_ = torch.randn(3, 2)
        >>> hidden_state = torch.randn(3, 2)
        >>> weight_drop_gru(input_, hidden_state)
        tensor(... grad_fn=<AddBackward0>)
    """

    def __init__(self, module, weights, dropout=0.0):
        super(WeightDrop, self).__init__()
        _weight_drop(module, weights, dropout)
        self.forward = module.forward


class WeightDropLSTM(torch.nn.LSTM):
    """
    Wrapper around :class:`torch.nn.LSTM` that adds a ``weight_dropout`` named argument.

    Args:
        weight_dropout (float): The probability a weight will be dropped.
    """

    def __init__(self, *args, weight_dropout=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
        _weight_drop(self, weights, weight_dropout)
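
# A minimal usage sketch (an assumed example, not from the library's docs):
# construct the LSTM as usual and pass the extra `weight_dropout` keyword.
#
#   lstm = WeightDropLSTM(10, 20, num_layers=2, weight_dropout=0.5)
#   output, (h_n, c_n) = lstm(torch.randn(5, 3, 10))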


class WeightDropGRU(torch.nn.GRU):
    """
    Wrapper around :class:`torch.nn.GRU` that adds a ``weight_dropout`` named argument.

    Args:
        weight_dropout (float): The probability a weight will be dropped.
    """

    def __init__(self, *args, weight_dropout=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
        _weight_drop(self, weights, weight_dropout)
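
# A hedged training note (an assumed example, not from the library's docs):
# the kept parameters carry the `_raw` suffix, so `gru.parameters()` still
# exposes everything the optimizer needs and gradients flow to the raw weights.
#
#   gru = WeightDropGRU(4, 4, num_layers=2, weight_dropout=0.5)
#   optimizer = torch.optim.SGD(gru.parameters(), lr=0.1)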


class WeightDropLinear(torch.nn.Linear):
    """
    Wrapper around :class:`torch.nn.Linear` that adds a ``weight_dropout`` named argument.

    Args:
        weight_dropout (float): The probability a weight will be dropped.
    """

    def __init__(self, *args, weight_dropout=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        weights = ['weight']
        _weight_drop(self, weights, weight_dropout)
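
# A minimal behavior sketch (an assumed example, not from the library's docs):
# the mask is sampled only while `module.training` is True, so the layer is
# deterministic after `eval()`.
#
#   linear = WeightDropLinear(5, 3, weight_dropout=0.9)
#   linear.eval()
#   torch.equal(linear(torch.ones(2, 5)), linear(torch.ones(2, 5)))  # True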