| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
def onehot(indexes, N=None):
    """
    Create a one-hot representation of ``indexes`` with ``N`` possible entries.

    Args:
        indexes: integer tensor of class indices, any shape ``(*)``.
        N: number of classes. If ``None``, it is inferred as
            ``indexes.max() + 1`` so that the largest index present fits.

    Returns:
        Long tensor of shape ``(*, N)`` on the same device as ``indexes``,
        holding 1 at each given index along the last dimension and 0 elsewhere.
    """
    # Always work with a plain Python int: a 0-dim tensor (what ``max()``
    # returns) is not a reliable size argument.
    N = int(indexes.max()) + 1 if N is None else int(N)
    output = torch.zeros((*indexes.shape, N), dtype=torch.long, device=indexes.device)
    # ``scatter_`` requires int64 indices; ``long()`` is a no-op for long input.
    output.scatter_(-1, indexes.long().unsqueeze(-1), 1)
    return output
| |
|
| |
|
class SmoothedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with label smoothing and an optional per-class mask."""

    def __init__(self, reduction='mean'):
        """
        Args:
            reduction: 'sum' or 'mean' to reduce over samples; any other value
                returns the unreduced per-sample loss.
        """
        super(SmoothedCrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, logits, labels, smooth_eps=0.1, mask=None, from_logits=True):
        """
        Args:
            logits: (N, Lv), unnormalized scores when ``from_logits`` is True,
                otherwise values that are used directly as log-probabilities.
            labels: (N, Lv) one-hot labels, or (N, ) class indices which are
                expanded to one-hot with ``onehot``.
            smooth_eps: float, probability mass spread uniformly over classes
                (over the unmasked classes of each sample when ``mask`` is given).
            mask: (N, Lv), multiplied into the smoothed targets; entries equal
                to 0 remove that class from the loss.
            from_logits: bool, apply ``log_softmax`` to ``logits`` first.

        Returns:
            Scalar tensor for 'sum' / 'mean' reduction, otherwise the
            per-sample loss with the last dimension summed out.
        """
        log_probs = F.log_softmax(logits, dim=-1) if from_logits else logits
        num_classes = log_probs.size(-1)

        # Labels with fewer dims than the predictions are class indices.
        if log_probs.dim() > labels.dim():
            labels = onehot(labels, num_classes).type(log_probs.dtype)

        if mask is None:
            targets = labels * (1 - smooth_eps) + smooth_eps / num_classes
        else:
            mask = mask.type(log_probs.dtype)
            valid_samples = torch.sum(mask, dim=-1, keepdim=True, dtype=log_probs.dtype)
            # A fully masked row would give eps / 0 = inf and then inf * 0 = NaN.
            # Its targets are zeroed by the mask anyway, so any finite divisor works.
            safe_count = torch.where(valid_samples == 0, torch.ones_like(valid_samples), valid_samples)
            eps_per_sample = smooth_eps / safe_count
            targets = (labels * (1 - smooth_eps) + eps_per_sample) * mask

        loss = -torch.sum(targets * log_probs, dim=-1)
        if self.reduction == 'sum':
            return torch.sum(loss)
        elif self.reduction == 'mean':
            return torch.mean(loss)
        else:
            return loss
| |
|
| |
|
class MILNCELoss(nn.Module):
    """
    MIL-NCE style contrastive loss over a square block score matrix.

    The scores are reshaped to ``(B, B, K)``. For sample ``i`` the entries
    ``[i, i, :]`` are the positives; the denominator runs over every entry of
    block-row ``i`` and block-column ``i``.
    """

    def __init__(self, reduction='mean'):
        """
        Args:
            reduction: 'sum' sums the per-sample losses; 'none' or any falsy
                value (``None``, ``False``, ``''``) returns them unreduced;
                every other value (including the default 'mean') averages.
        """
        super(MILNCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, q2ctx_scores=None, contexts=None, queries=None):
        """
        Args:
            q2ctx_scores: precomputed scores whose first dimension is ``B`` and
                whose remaining elements can be reshaped to ``(B, K)``. When
                given, ``contexts`` and ``queries`` are ignored.
            contexts: used only when ``q2ctx_scores`` is None; the scores are
                then ``contexts @ queries.t()``.
            queries: see ``contexts``.

        Returns:
            Scalar tensor, or a ``(B, )`` tensor when unreduced.
        """
        if q2ctx_scores is None:
            assert contexts is not None and queries is not None, \
                "either q2ctx_scores or both contexts and queries must be given"
            scores = torch.matmul(contexts, queries.t())
        else:
            scores = q2ctx_scores
        bsz = scores.shape[0]
        # reshape (not view) so non-contiguous score tensors are accepted too.
        x = scores.reshape(bsz, bsz, -1)

        # Positives are the block diagonal x[i, i, :]. Reading them directly,
        # rather than multiplying by an identity mask and summing, avoids
        # inf * 0 = NaN when off-diagonal scores are +/-inf.
        positives = torch.diagonal(x, dim1=0, dim2=1).t()  # (B, K)
        nominator = torch.logsumexp(positives, dim=1)

        # Row i and column i of the block matrix, flattened per sample.
        denominator = torch.cat((x, x.permute(1, 0, 2)), dim=1).reshape(bsz, -1)
        denominator = torch.logsumexp(denominator, dim=1)

        loss = denominator - nominator
        if self.reduction == 'sum':
            return torch.sum(loss)
        if self.reduction == 'none' or not self.reduction:
            return loss
        return torch.mean(loss)
| |
|
| |
|
class DepthwiseSeparableConv(nn.Module):
    """
    Depth-wise separable convolution: a per-channel (grouped) convolution
    followed by a 1x1 point-wise convolution, which uses fewer parameters
    than a full convolution.

    The module takes channel-last input and transposes it internally.

    :Examples:
        >>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
        >>> input_tensor = torch.randn(32, 20, 300)
        >>> output = m(input_tensor)
        >>> output.shape
        torch.Size([32, 20, 200])
    """

    def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
        """
        :param in_ch: input hidden dimension size
        :param out_ch: output hidden dimension size
        :param k: kernel size
        :param dim: default 1. 1D conv or 2D conv
        :param relu: default True. Apply ReLU to the point-wise output
        :raises ValueError: if ``dim`` is neither 1 nor 2
        """
        super(DepthwiseSeparableConv, self).__init__()
        self.relu = relu
        if dim == 1:
            conv_cls = nn.Conv1d
        elif dim == 2:
            conv_cls = nn.Conv2d
        else:
            raise ValueError("Incorrect dimension!")
        # groups=in_ch gives every input channel its own filter.
        self.depthwise_conv = conv_cls(in_channels=in_ch, out_channels=in_ch, kernel_size=k, groups=in_ch,
                                       padding=k // 2)
        # The 1x1 convolution mixes channels and maps in_ch -> out_ch.
        self.pointwise_conv = conv_cls(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0)

    def forward(self, x):
        """
        :Input: (N, L_in, in_ch)
        :Output: (N, L_out, out_ch); with padding ``k // 2``, L_out equals
            L_in for odd ``k`` and L_in + 1 for even ``k``.

        Axes 1 and 2 are swapped before and after the convolutions regardless
        of ``dim``. NOTE(review): the expected input layout for ``dim=2`` is
        not documented in this file; verify against callers.
        """
        x = x.transpose(1, 2)  # channel-last -> channel-first for nn.Conv*
        out = self.pointwise_conv(self.depthwise_conv(x))
        if self.relu:
            out = F.relu(out, inplace=True)
        return out.transpose(1, 2)
| |
|
| |
|
class ConvEncoder(nn.Module):
    """Residual convolution block: ``LayerNorm(Dropout(Conv(x)) + x)``."""

    def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
        """
        :param kernel_size: kernel size of the depth-wise separable convolution
        :param n_filters: feature size, used for both conv input and output
        :param dropout: dropout probability applied to the conv output
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(n_filters)
        self.conv = DepthwiseSeparableConv(n_filters, n_filters, kernel_size, relu=True)

    def forward(self, x):
        """
        :param x: (N, L, D)
        :return: (N, L, D)
        """
        conv_out = self.conv(x)
        # Residual connection around the (dropped-out) convolution.
        residual_sum = self.dropout(conv_out) + x
        return self.layer_norm(residual_sum)
| |
|
| |
|
class TrainablePositionalEncoding(nn.Module):
    """Add learned absolute position embeddings to a sequence of features."""

    def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
        """
        Args:
            max_position_embeddings: number of positions in the embedding
                table, i.e. the longest supported sequence length.
            hidden_size: feature size D of both input and position embeddings.
            dropout: dropout probability applied in ``forward``.
        """
        super(TrainablePositionalEncoding, self).__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_feat):
        """
        Args:
            input_feat: (N, L, D)

        Returns:
            (N, L, D): ``Dropout(LayerNorm(input_feat + position_embeddings))``.
        """
        embeddings = self.LayerNorm(self.add_position_emb(input_feat))
        return self.dropout(embeddings)

    def add_position_emb(self, input_feat):
        """
        Add position embeddings only, without LayerNorm or dropout.

        Args:
            input_feat: (N, L, D)

        Returns:
            (N, L, D)

        Raises:
            IndexError: if L exceeds the size of the position embedding table.
        """
        seq_length = input_feat.shape[1]
        max_length = self.position_embeddings.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                "sequence length %d exceeds max_position_embeddings %d" % (seq_length, max_length))
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
        # (1, L, D): the same embeddings apply to every sample, so broadcast
        # over the batch instead of materializing one copy of the ids per sample.
        position_embeddings = self.position_embeddings(position_ids).unsqueeze(0)
        return input_feat + position_embeddings
| |
|
| |
|
class LinearLayer(nn.Module):
    """Dropout + linear projection, optionally preceded by LayerNorm and followed by ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        """
        Args:
            in_hsz: input feature size.
            out_hsz: output feature size.
            layer_norm: normalize the input over its last dimension first.
            dropout: dropout probability applied before the projection.
            relu: apply ReLU to the projected output.
        """
        super().__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz),
        )

    def forward(self, x):
        """(N, L, D)"""
        normed = self.LayerNorm(x) if self.layer_norm else x
        projected = self.net(normed)
        if not self.relu:
            return projected
        return F.relu(projected, inplace=True)
| |
|
| |
|
class BertLayer(nn.Module):
    """Transformer encoder layer: optional self-attention, then a feed-forward block."""

    def __init__(self, config, use_self_attention=True):
        """
        Args:
            config: configuration object handed to the sub-modules.
            use_self_attention: when False, no attention sub-module is built
                and the input goes straight to the feed-forward block.
        """
        super().__init__()
        self.use_self_attention = use_self_attention
        if use_self_attention:
            self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        """
        Args:
            hidden_states: (N, L, D)
            attention_mask: 1 indicates valid, 0 indicates invalid. Passed
                unchanged to the attention sub-module, so its shape must be
                one that sub-module accepts. Unused when self-attention is
                disabled.
        """
        attention_output = (
            self.attention(hidden_states, attention_mask)
            if self.use_self_attention
            else hidden_states
        )
        # BertOutput receives the attention output as its second argument.
        return self.output(self.intermediate(attention_output), attention_output)
| |
|
| |
|
class BertAttention(nn.Module):
    """Self-attention followed by its output projection block."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """
        Args:
            input_tensor: (N, L, D), used as query, key and value.
            attention_mask: passed unchanged to the self-attention module, so
                its shape must be one that module accepts.
        """
        attended = self.self(
            query_states=input_tensor,
            key_states=input_tensor,
            value_states=input_tensor,
            attention_mask=attention_mask,
        )
        return self.output(attended, input_tensor)
| |
|
| |
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: linear ``hidden_size -> intermediate_size`` then in-place ReLU."""

    def __init__(self, config):
        super().__init__()
        expansion = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dense = nn.Sequential(expansion, nn.ReLU(inplace=True))

    def forward(self, hidden_states):
        """Map (..., hidden_size) features to (..., intermediate_size)."""
        expanded = self.dense(hidden_states)
        return expanded
| |
|
| |
|
class BertOutput(nn.Module):
    """Feed-forward contraction with residual: ``LayerNorm(Dropout(Linear(h)) + residual)``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Args:
            hidden_states: (..., intermediate_size) features to project down.
            input_tensor: (..., hidden_size) residual added before LayerNorm.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
| |
|
| |
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with an additive validity mask."""

    def __init__(self, config):
        """
        Args:
            config: object providing ``hidden_size``, ``num_attention_heads``
                and ``attention_probs_dropout_prob``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by the head count.
        """
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention heads (%d)" % (
                config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (N, L, D) -> (N, num_heads, L, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query_states, key_states, value_states, attention_mask):
        """
        Args:
            query_states: (N, Lq, D)
            key_states: (N, L, D)
            value_states: (N, L, D)
            attention_mask: (N, Lq, L), (N, 1, L) or (N, L); non-zero / True
                marks a valid key position, 0 / False an invalid one. Any
                numeric or bool dtype is accepted.

        Returns:
            (N, Lq, D) context vectors.
        """
        query_layer = self.transpose_for_scores(self.query(query_states))  # (N, H, Lq, d)
        key_layer = self.transpose_for_scores(self.key(key_states))  # (N, H, L, d)
        value_layer = self.transpose_for_scores(self.value(value_states))  # (N, H, L, d)

        # (N, H, Lq, L)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Insert the head axis. A 2-D key mask also needs a query axis;
        # otherwise its batch axis would be broadcast against the heads.
        if attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :]  # (N, 1, 1, L)
        else:
            attention_mask = attention_mask.unsqueeze(1)  # (N, 1, Lq or 1, L)
        # Only the softmax (last) dimension needs masking. Casting to the score
        # dtype supports bool/int masks and avoids dtype promotion of the scores.
        additive_mask = (1 - attention_mask.to(dtype=attention_scores.dtype)) * -10000.
        attention_scores = attention_scores + additive_mask

        attention_probs = torch.softmax(attention_scores, dim=-1)
        # Dropout on the probabilities drops whole keys for a given query.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)  # (N, H, Lq, d)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)
| |
|
| |
|
class BertSelfOutput(nn.Module):
    """Attention output projection with residual: ``LayerNorm(Dropout(Linear(h)) + residual)``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Args:
            hidden_states: (..., hidden_size) features to project.
            input_tensor: (..., hidden_size) residual added before LayerNorm.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
| |
|