from transformers import PreTrainedModel, PretrainedConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import torch
import torch.nn as nn
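# A tiny "recursive" causal LM: rather than stacking n distinct transformer
# layers, a small set of physical GPT-2 blocks (n_physical_layers, default 3)
# is applied round-robin for n_loops passes (default 8), so the effective
# depth exceeds the parameter depth.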
class TRMConfig(PretrainedConfig):
    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # Aliases the reused GPT-2 modules read off the config. GPT2Attention
        # also reads max_position_embeddings when building its causal-mask
        # buffer, so it must be set here as well.
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.max_position_embeddings = n_positions
        self.n_inner = None
        self.is_encoder_decoder = False
class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and learned position embeddings, GPT-2 style.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # The small stack of physical blocks that forward() reuses round-robin.
        self.physical_blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "attn": GPT2Attention(config, layer_idx=i),
                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "mlp": GPT2MLP(4 * config.n_embd, config),
            }) for i in range(config.n_physical_layers)
        ])

        # Final layer norm before the LM head.
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Output projection; tied to wte by tie_weights() via _tied_weights_keys.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply embedding tying.
        self.post_init()
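    # PreTrainedModel does not know where a custom model keeps its embeddings;
    # these accessors let post_init()/tie_weights() tie lm_head.weight to
    # wte.weight as _tied_weights_keys advertises, and let helpers such as
    # resize_token_embeddings() find the right modules.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # The base _init_weights is a no-op for custom models. A minimal sketch,
    # assuming GPT-2's usual scheme (normal(0, 0.02) weights, zero biases,
    # unit LayerNorm gains); the std here is an assumption, not a tuned choice.
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)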
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Token plus learned absolute position embeddings, as in GPT-2.
        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        # GPT2Attention expects an additive 4-D mask (0 to keep, a large
        # negative value to mask padding), not the 0/1 [batch, seq] mask,
        # so convert it once before the loop.
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Recursive core: cycle through the shared physical blocks so the
        # effective depth is n_loops while parameters stay at n_physical_layers.
        for loop in range(self.config.n_loops):
            block_idx = loop % self.config.n_physical_layers
            block = self.physical_blocks[block_idx]

            # Pre-LN self-attention sub-block with residual connection.
            ln_output = block["ln_1"](hidden_states)
            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
            hidden_states = hidden_states + attn_output

            # Pre-LN MLP sub-block with residual connection.
            ln_output = block["ln_2"](hidden_states)
            mlp_output = block["mlp"](ln_output)
            hidden_states = hidden_states + mlp_output

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=(hidden_states,),  # the field expects a tuple of tensors
            attentions=None,
            cross_attentions=None,
        )
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache, so every step re-runs the full sequence; forward the
        # attention mask that generate() extends by one column per new token.
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask")}

    def _reorder_cache(self, past, beam_idx):
        # Nothing is cached, so there is nothing to reorder for beam search.
        return past
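# Minimal smoke test: a sketch assuming the default hyperparameters above and
# random token ids (no tokenizer). The Auto* registration is optional but lets
# AutoModelForCausalLM resolve model_type="recursive_gpt" from a saved config.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("recursive_gpt", TRMConfig)
    AutoModelForCausalLM.register(TRMConfig, TinyRecursiveModel)

    config = TRMConfig()
    model = TinyRecursiveModel(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    # Training-style pass: labels trigger the shifted cross-entropy loss.
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print(out.loss.item(), tuple(out.logits.shape))  # scalar, (2, 16, 50257)

    # Greedy decoding exercises the GenerationMixin path (full recompute each
    # step, since this model keeps no KV cache).
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print(tuple(generated.shape))  # (2, 24)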