import torch
import torch.nn as nn
[docs]class Attention(nn.Module):
""" Applies attention mechanism on the `context` using the `query`.
**Thank you** to IBM for their initial implementation of :class:`Attention`. Here is
their `License
<https://github.com/IBM/pytorch-seq2seq/blob/master/LICENSE>`__.
Args:
dimensions (int): Dimensionality of the query and context.
attention_type (str, optional): How to compute the attention score:
* dot: :math:`score(H_j,q) = H_j^T q`
* general: :math:`score(H_j, q) = H_j^T W_a q`
Example:
>>> attention = Attention(256)
>>> query = torch.randn(5, 1, 256)
>>> context = torch.randn(5, 5, 256)
>>> output, weights = attention(query, context)
>>> output.size()
torch.Size([5, 1, 256])
>>> weights.size()
torch.Size([5, 1, 5])
"""
def __init__(self, dimensions, attention_type='general'):
super(Attention, self).__init__()
if attention_type not in ['dot', 'general']:
raise ValueError('Invalid attention type selected.')
self.attention_type = attention_type
if self.attention_type == 'general':
self.linear_in = nn.Linear(dimensions, dimensions, bias=False)
self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
self.softmax = nn.Softmax(dim=-1)
self.tanh = nn.Tanh()
[docs] def forward(self, query, context):
"""
Args:
query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
queries to query the context.
context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
overwhich to apply the attention mechanism.
Returns:
:class:`tuple` with `output` and `weights`:
* **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]):
Tensor containing the attended features.
* **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
Tensor containing attention weights.
"""
batch_size, output_len, dimensions = query.size()
query_len = context.size(1)
if self.attention_type == "general":
query = query.reshape(batch_size * output_len, dimensions)
query = self.linear_in(query)
query = query.reshape(batch_size, output_len, dimensions)
# TODO: Include mask on PADDING_INDEX?
# (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
# (batch_size, output_len, query_len)
attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
# Compute weights across every context sequence
attention_scores = attention_scores.view(batch_size * output_len, query_len)
attention_weights = self.softmax(attention_scores)
attention_weights = attention_weights.view(batch_size, output_len, query_len)
# (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
# (batch_size, output_len, dimensions)
mix = torch.bmm(attention_weights, context)
# concat -> (batch_size * output_len, 2*dimensions)
combined = torch.cat((mix, query), dim=2)
combined = combined.view(batch_size * output_len, 2 * dimensions)
# Apply linear_out on every 2nd dimension of concat
# output -> (batch_size, output_len, dimensions)
output = self.linear_out(combined).view(batch_size, output_len, dimensions)
output = self.tanh(output)
return output, attention_weights