| import torch.nn as nn |
|
|
| from modules.token_embedders.bert_encoder import BertEncoder |
| from utils.nn_utils import batched_index_select, gelu |
|
|
|
|
class BertEmbedModel(nn.Module):
    """Embedding layer backed by a BERT encoder.

    Wraps a ``BertEncoder`` and exposes its contextual representations as
    sequence-token embeddings, writing the results back into the batch dict.
    """
    def __init__(self, cfg, vocab, rel_mlp=False):
        """Build the underlying BERT encoder from config.

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {Vocabulary} -- vocabulary
            rel_mlp {bool} -- when True, `forward` skips the wordpiece-to-token
                index selection and exposes raw wordpiece representations
        """

        super().__init__()
        self.rel_mlp = rel_mlp
        self.activation = gelu
        self.bert_encoder = BertEncoder(
            bert_model_name=cfg.bert_model_name,
            trainable=cfg.fine_tune,
            output_size=cfg.bert_output_size,
            activation=self.activation,
            dropout=cfg.bert_dropout,
        )
        self.encoder_output_size = self.bert_encoder.get_output_dims()

    def forward(self, batch_inputs):
        """Run the BERT encoder and store representations in `batch_inputs`.

        Mutates `batch_inputs` in place, adding:
            'seq_encoder_reprs' -- per-token (or per-wordpiece, if `rel_mlp`)
                contextual representations
            'seq_cls_repr' -- the [CLS] pooled representation

        Arguments:
            batch_inputs {dict} -- batch input data
        """

        # Only pass segment ids through when the batch actually provides them.
        if 'wordpiece_segment_ids' in batch_inputs:
            encoder_args = (batch_inputs['wordpiece_tokens'],
                            batch_inputs['wordpiece_segment_ids'])
        else:
            encoder_args = (batch_inputs['wordpiece_tokens'],)
        seq_wordpiece_repr, cls_repr = self.bert_encoder(*encoder_args)

        if self.rel_mlp:
            # Expose raw wordpiece-level representations unchanged.
            batch_inputs['seq_encoder_reprs'] = seq_wordpiece_repr
        else:
            # Gather one representation per original token via the
            # wordpiece-to-token index mapping.
            batch_inputs['seq_encoder_reprs'] = batched_index_select(
                seq_wordpiece_repr, batch_inputs['wordpiece_tokens_index'])

        batch_inputs['seq_cls_repr'] = cls_repr

    def get_hidden_size(self):
        """Return the dimensionality of the produced embeddings.

        Returns:
            int -- embedding dimensions
        """

        return self.encoder_output_size
|
|