| import torch |
|
|
| |
| from builtin_architecture import make_model |
| import os |
| import sys |
| import time |
| from dataset import dataset, get_train_dataset |
| import torch.nn.functional as F |
|
|
# Directory holding checkpoints (ckpt/best.pt) for this experiment run.
EXPERIMENT_DIRECTORY = "runs/code-decoder-v10-vanilla-smaller-batchfirst"

# Prefer the Apple-Silicon GPU backend when available, otherwise CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# NOTE(review): this unconditionally overrides the MPS probe above and forces
# CPU inference. Delete this line to re-enable MPS.
device = "cpu"

# Build the architecture and move it to the target device.
net = make_model()
net.to(device)

# Restore the best checkpoint.
# BUG FIX: map_location remaps tensors saved on another device (e.g. a GPU)
# onto the current one; without it, loading a GPU-saved checkpoint on a
# CPU-only machine fails. weights_only=True keeps the safe pickle loader.
net.load_state_dict(
    torch.load(
        os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "best.pt"),
        map_location=device,
        weights_only=True,
    )
)
|
|
|
|
# Sanity-check the restored weights: report any parameter — or gradient, if
# one exists — that contains NaNs. Merged from two separate passes over
# named_parameters() into one; since no backward pass has run in this script,
# every .grad is None and the printed output is identical to the original.
for name, param in net.named_parameters():
    if torch.isnan(param).any():
        print(f"NaN found in {name}")
    if param.grad is not None and torch.isnan(param.grad).any():
        print(f"NaN found in gradients of {name}")
|
|
|
|
# Special-token ids for the tokenizer (sep is unused in this script).
pad_token_id = 0
sep_token_id = None

# Read the prompt and cap the number of tokens to generate.
input_text = input("Prompt: ")
max_length = 100

# Encode the prompt. torch.long (int64) is the dtype embedding layers expect;
# `dtype=int` resolved to the same thing but the explicit form is idiomatic.
input_ids = torch.tensor(dataset.manager.encode(input_text), dtype=torch.long)
print(input_ids.shape)
attention_mask = dataset.manager.attention_mask(input_ids.squeeze(0)).to(device)

# Round-trip the prompt through the tokenizer as a sanity check.
generated_text = dataset.manager.decode(input_ids)
print(generated_text)

# NOTE(review): the encoded prompt and attention mask above are discarded
# here — generation actually starts from a single random token id drawn from
# [0, 199). Confirm whether seeding generation from the prompt was intended.
generated_text = ""
input_ids = torch.randint(199, (1, 1), dtype=torch.long).to(device)
|
|
# Autoregressive top-k sampling with temperature.
net.eval()
temp = 1.0  # softmax temperature; 1.0 leaves the distribution unchanged

for _ in range(max_length):
    with torch.no_grad():
        output = net(input_ids)
        # NOTE(review): output[-1] assumes the last index selects the
        # final-position logits; if the model returns (batch, seq, vocab)
        # — the run name says "batchfirst" — this picks the last *batch*
        # row instead. Confirm the model's output layout.
        #
        # BUG FIX: torch.multinomial requires non-negative weights, but the
        # original sampled from log_softmax output (all values <= 0), which
        # raises "invalid multinomial distribution" on the first step.
        # Apply temperature to the raw logits, then take a proper softmax.
        word_weights = F.softmax(output[-1].div(temp), dim=-1).cpu()

    # Restrict sampling to the k most likely tokens.
    top_k = 10
    vocab_size = word_weights.size(0)
    top_k = min(top_k, vocab_size)

    top_probs, top_indices = torch.topk(word_weights, k=top_k)

    # With a single candidate there is nothing to sample.
    if top_probs.size(0) == 1:
        word_idx = top_indices[0]
    else:
        sampled_idx = torch.multinomial(top_probs, 1).item()
        word_idx = top_indices[sampled_idx]

    # Decode and accumulate the sampled token.
    print(word_idx)
    predicted_token = dataset.manager.decode(word_idx.item())
    print(predicted_token, end=" ")
    generated_text += predicted_token

    # Debug output for the current step's distribution.
    print("Word Weights:", word_weights)
    print("Top Probabilities:", top_probs)
    print("Top Indices:", top_indices)

    # Append the sampled token and feed the extended sequence back in.
    word_tensor = torch.tensor([[word_idx]], dtype=torch.long).to(device)
    input_ids = torch.cat([input_ids, word_tensor], dim=1)
|
|
# Persist the full generated sequence.
print("\nGenerated text:", generated_text)
# "w" (not "w+") — the file is only written, never read back; explicit utf-8
# avoids platform-dependent default encodings for non-ASCII tokens.
with open("output.txt", "w", encoding="utf-8") as f:
    f.write(generated_text)
|
|