| | import numpy as np |
| | import torch |
| |
|
| | from seq2struct.utils import registry |
| | from seq2struct.models import transformer |
| |
|
| |
|
def maybe_mask(attn, attn_mask):
    """Overwrite masked attention logits with ``-inf``, in place.

    Args:
        attn: tensor of attention logits; modified in place through
            ``attn.data``.
        attn_mask: mask passed to ``masked_fill_`` (positions where it is
            set are filled), or ``None`` to leave ``attn`` untouched.

    Returns:
        None. The result is visible only through the mutation of ``attn``.

    Raises:
        AssertionError: if some pair of trailing dimensions of the two
            shapes is neither equal nor contains a 1.
    """
    if attn_mask is None:
        return

    # Compare the shapes right-aligned, the way broadcasting lines them up.
    # zip() stops at the shorter shape, so extra leading dims are unchecked.
    trailing_dims = zip(reversed(attn.shape), reversed(attn_mask.shape))
    shapes_compatible = all(
        1 in (attn_dim, mask_dim) or attn_dim == mask_dim
        for attn_dim, mask_dim in trailing_dims)
    assert shapes_compatible, \
        'Attention mask shape {} should be broadcastable with attention shape {}'.format(
            attn_mask.shape, attn.shape)

    # The fill goes through .data, i.e. it writes to the tensor's storage
    # directly rather than through the tensor itself.
    attn.data.masked_fill_(attn_mask, float('-inf'))
| |
|
| |
|
class Attention(torch.nn.Module):
    """Attention layer built on top of a pointer (logit-producing) callable.

    The pointer is called as ``pointer(query, values, attn_mask)`` and its
    output is normalised with a softmax over the last dimension.
    """

    def __init__(self, pointer):
        """Store the pointer and create the softmax used to normalise logits."""
        super().__init__()
        self.pointer = pointer
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, query, values, attn_mask=None):
        """Return the attention-weighted combination of ``values``.

        Args:
            query: forwarded unchanged to the pointer.
            values: 3-D tensor (``torch.bmm`` operand); also forwarded to the
                pointer.
            attn_mask: optional mask, forwarded unchanged to the pointer.

        Returns:
            A ``(output, attn)`` tuple: ``attn`` is the softmax of the pointer
            logits and ``output`` is ``attn`` applied to ``values`` with the
            temporary middle dimension removed.
        """
        weights = self.softmax(self.pointer(query, values, attn_mask))
        # bmm needs 3-D operands: add a singleton row dimension to the
        # weights for the product, then drop it again.
        context = torch.bmm(weights.unsqueeze(1), values).squeeze(1)
        return context, weights
| |
|
| |
|
@registry.register('pointer', 'sdp')
class ScaledDotProductPointer(torch.nn.Module):
    """Scaled dot-product scores between a projected query and a set of keys.

    Registered through ``registry.register('pointer', 'sdp')``.
    """

    def __init__(self, query_size, key_size):
        """Create the query projection and the scaling constant.

        Args:
            query_size: input feature size of the query projection.
            key_size: output feature size of the query projection; its square
                root is the divisor applied to the raw scores.
        """
        super().__init__()
        self.query_proj = torch.nn.Linear(query_size, key_size)
        self.temp = np.power(key_size, 0.5)

    def forward(self, query, keys, attn_mask=None):
        """Compute (optionally masked) attention logits.

        Args:
            query: 2-D tensor with ``query_size`` features; it is projected
                and then given a trailing singleton dimension so it can be a
                ``torch.bmm`` operand.
            keys: 3-D tensor whose last dimension matches ``key_size``.
            attn_mask: optional mask handed to ``maybe_mask``.

        Returns:
            One logit per key: the batched product with its trailing
            singleton dimension removed, divided by ``self.temp``.
        """
        projected = self.query_proj(query)
        # Treat the projected query as a column vector per batch element so
        # the batched matrix product yields one score per key.
        raw_scores = torch.bmm(keys, projected.unsqueeze(2)).squeeze(2)
        scaled_scores = raw_scores / self.temp
        maybe_mask(scaled_scores, attn_mask)
        return scaled_scores
| |
|
| |
|
@registry.register('attention', 'sdp')
class ScaledDotProductAttention(Attention):
    """``Attention`` whose logits come from a ``ScaledDotProductPointer``.

    Registered through ``registry.register('attention', 'sdp')``.
    """

    def __init__(self, query_size, value_size):
        """Build the pointer from ``query_size`` and ``value_size``.

        ``value_size`` is passed as the pointer's key size.
        """
        pointer = ScaledDotProductPointer(query_size, value_size)
        super().__init__(pointer)
| |
|
| |
|
@registry.register('pointer', 'bahdanau')
class BahdanauPointer(torch.nn.Module):
    """Scores each key with a small MLP applied to the (query, key) pair.

    Registered through ``registry.register('pointer', 'bahdanau')``.
    """

    def __init__(self, query_size, key_size, proj_size):
        """Build the scoring network: Linear -> Tanh -> Linear(., 1).

        Args:
            query_size: feature size of the query.
            key_size: feature size of each key.
            proj_size: width of the hidden layer.
        """
        super().__init__()
        hidden_layer = torch.nn.Linear(query_size + key_size, proj_size)
        score_layer = torch.nn.Linear(proj_size, 1)
        self.compute_scores = torch.nn.Sequential(
            hidden_layer,
            torch.nn.Tanh(),
            score_layer)

    def forward(self, query: torch.Tensor, keys: torch.Tensor, attn_mask=None):
        """Compute (optionally masked) attention logits.

        Args:
            query: 2-D tensor; it is repeated along a new dimension 1 so that
                there is one copy per key.
            keys: 3-D tensor; dimension 1 indexes the keys and dimension 2
                holds the features that are concatenated with the query.
            attn_mask: optional mask handed to ``maybe_mask``.

        Returns:
            One logit per key: the network output with its trailing size-1
            dimension removed.
        """
        num_keys = keys.shape[1]
        # expand() broadcasts the query across the key dimension.
        tiled_query = query.unsqueeze(1).expand(-1, num_keys, -1)
        pair_features = torch.cat((tiled_query, keys), dim=2)
        scores = self.compute_scores(pair_features).squeeze(2)
        maybe_mask(scores, attn_mask)
        return scores
| |
|
| |
|
@registry.register('attention', 'bahdanau')
class BahdanauAttention(Attention):
    """``Attention`` whose logits come from a ``BahdanauPointer``.

    Registered through ``registry.register('attention', 'bahdanau')``.
    """

    def __init__(self, query_size, value_size, proj_size):
        """Build the pointer from the given sizes.

        ``value_size`` is passed as the pointer's key size.
        """
        pointer = BahdanauPointer(query_size, value_size, proj_size)
        super().__init__(pointer)
| |
|
| |
|
| | |
class MultiHeadedAttention(torch.nn.Module):
    """Multi-head attention of a query over a set of values.

    Holds four linear layers: projections for the query, the keys and the
    values (keys and values are both computed from ``values``), plus a final
    output projection.
    """

    def __init__(self, h, query_size, value_size, dropout=0.1):
        """Create the projections and the dropout layer.

        Args:
            h: number of heads; must divide both ``query_size`` and
                ``value_size``.
            query_size: input feature size of the query projection.
            value_size: feature size of the values and of every projection
                output.
            dropout: dropout probability handed to ``torch.nn.Dropout``.

        Raises:
            AssertionError: if ``h`` does not divide ``query_size`` or
                ``value_size``.
        """
        super().__init__()
        assert query_size % h == 0
        assert value_size % h == 0

        # Per-head feature size.
        self.d_k = value_size // h
        self.h = h

        # Query, key, value and output projections, in that order; only the
        # first one consumes query_size features.
        input_sizes = (query_size, value_size, value_size, value_size)
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(in_size, value_size) for in_size in input_sizes])

        # Attention weights of the most recent forward pass.
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, query, values, attn_mask=None):
        """Implements Figure 2 (multi-head attention).

        NOTE(review): presumably Figure 2 of "Attention Is All You Need";
        confirm against ``transformer.attention``.

        Args:
            query: tensor whose first dimension is the batch size.
            values: tensor used to compute both the keys and the values.
            attn_mask: optional mask; a dimension is inserted at position 1
                before it is passed to ``transformer.attention``.

        Returns:
            A tuple of the output-projected attention result and the
            attention weights returned by ``transformer.attention`` (also
            stored in ``self.attn``).
        """
        if attn_mask is not None:
            # Insert a dimension at position 1 (where the projected tensors
            # carry their head dimension).
            attn_mask = attn_mask.unsqueeze(1)
        nbatches = query.size(0)

        def split_heads(projection, tensor):
            # Project, then reshape to (batch, h, -1, d_k).
            projected = projection(tensor)
            return projected.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        query_proj, key_proj, value_proj, output_proj = self.linears
        head_queries = split_heads(query_proj, query)
        head_keys = split_heads(key_proj, values)
        head_values = split_heads(value_proj, values)

        x, self.attn = transformer.attention(
            head_queries, head_keys, head_values,
            mask=attn_mask, dropout=self.dropout)

        # Merge the heads back into a single feature dimension.
        merged = x.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)
        merged = merged.squeeze(1)
        return output_proj(merged), self.attn
| |
|