| import torch |
|
|
| |
| from builtin_architecture import make_model |
| import os |
| import sys |
| import time |
| from dataset import dataset, get_train_dataset, get_dataloader |
| import torch.nn.functional as F |
| from tqdm import tqdm, trange |
| import heapq |
|
|
EXPERIMENT_DIRECTORY = "runs/code-decoder-v23-mega"

# Inference is pinned to CPU by default. MPS auto-detection is opt-in: set
# USE_MPS = True to use Apple's MPS backend when it is available.
USE_MPS = False

device = "mps" if USE_MPS and torch.backends.mps.is_available() else "cpu"
|
|
|
|
def evaluate_topk(model, start_sequence, amt=10, k=20, temperature=1.0, device="cpu"):
    """Extend ``start_sequence`` by ``amt`` tokens using top-k sampling.

    At each step the model is run on the whole sequence generated so far. The
    logits of the final position are divided by ``temperature``. The next
    token is sampled from the softmax over the ``k`` highest logits.

    Args:
        model: Callable invoked as ``model(seq, transpose=True)``. Its output
            is transposed on dims 0/1, and the last row of the flattened
            result is used as the next-token logits.
        start_sequence: Prompt token ids. The concatenation below appends a
            ``(1, 1)`` token, so this presumably must be a single sequence of
            shape ``(1, seq_len)`` -- TODO confirm with callers.
        amt: Number of tokens to generate.
        k: Number of highest-scoring tokens to sample from. It is clamped to
            the vocabulary size so ``torch.topk`` cannot fail.
        temperature: Positive divisor applied to the logits before sampling.
        device: Device the generated sequence is moved to.

    Returns:
        The prompt with ``amt`` sampled token ids appended along dim 1.

    Raises:
        ValueError: If ``temperature`` is not positive or ``k`` is below 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="topk"):
            results = model(generated_sequence, transpose=True).transpose(0, 1)
            # Next-token logits = last row of the flattened output.
            logits = results.reshape(-1, results.size(-1))[-1] / temperature

            # Clamp k so a small vocabulary does not make torch.topk raise.
            top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)))
            top_k_probs = F.softmax(top_k_values, dim=-1)

            sampled_index = torch.multinomial(top_k_probs, 1).item()
            next_token = top_k_indices[sampled_index].view(1, 1)

            generated_sequence = torch.cat((generated_sequence, next_token), dim=1)

    return generated_sequence
|
|
|
|
def evaluate_topp(model, start_sequence, amt=10, p=0.9, temperature=1.0, device="cpu"):
    """Extend ``start_sequence`` by ``amt`` tokens using nucleus (top-p) sampling.

    At each step the final-position logits are divided by ``temperature`` and
    turned into probabilities. Tokens are sorted by probability, and the
    smallest prefix is kept up to and including the first token whose
    cumulative probability exceeds ``p``. The next token is sampled from that
    renormalized prefix.

    Args:
        model: Callable invoked as ``model(seq, transpose=True)``. Its output
            is transposed on dims 0/1, and the last row of the flattened
            result is used as the next-token logits.
        start_sequence: Prompt token ids. Presumably ``(1, seq_len)``, since a
            ``(1, 1)`` token is appended each step -- TODO confirm.
        amt: Number of tokens to generate.
        p: Cumulative-probability threshold. If no prefix exceeds it (e.g.
            ``p >= 1.0`` or rounding), the whole distribution is used.
        temperature: Positive divisor applied to the logits before softmax.
        device: Device the generated sequence is moved to.

    Returns:
        The prompt with ``amt`` sampled token ids appended along dim 1.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="topp"):
            results = model(generated_sequence, transpose=True).transpose(0, 1)
            logits = results.reshape(-1, results.size(-1))[-1] / temperature
            probs = F.softmax(logits, dim=-1)

            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Index of the first token pushing the cumulative mass above p.
            # When none does, keep everything instead of indexing an empty
            # tensor (which previously raised IndexError).
            (over_p,) = torch.where(cumulative_probs > p)
            if over_p.numel() > 0:
                cutoff_idx = int(over_p[0]) + 1
            else:
                cutoff_idx = sorted_probs.numel()

            top_p_probs = sorted_probs[:cutoff_idx]
            top_p_indices = sorted_indices[:cutoff_idx]
            top_p_probs = top_p_probs / top_p_probs.sum()

            sampled_index = torch.multinomial(top_p_probs, 1).item()
            next_token = top_p_indices[sampled_index].view(1, 1)

            generated_sequence = torch.cat((generated_sequence, next_token), dim=1)

    return generated_sequence
|
|
|
|
def evaluate_beam(model, start_sequence, k=2, amt=10, temperature=0.8, device="cpu"):
    """Extend ``start_sequence`` by ``amt`` tokens using beam search.

    Up to ``k`` hypotheses are kept, ranked by the cumulative log-probability
    of their generated tokens. Logits are divided by ``temperature`` before
    ``log_softmax``. The search starts from a single beam: starting from ``k``
    identical copies of the prompt would make every slot pick the same
    candidate, collapsing the search to greedy decoding.

    Args:
        model: Callable invoked as ``model(seq, transpose=True)`` on one
            sequence of shape ``(1, seq_len)``. Its output is transposed on
            dims 0/1, and ``[:, -1, :]`` is used as the next-token logits.
        start_sequence: Prompt token ids, shape ``(seq_len,)`` or
            ``(1, seq_len)``.
        k: Beam width.
        amt: Number of tokens to generate.
        temperature: Positive divisor applied to the logits before scoring.
        device: Device the beams are moved to.

    Returns:
        1-D tensor holding the highest-scoring beam (prompt plus ``amt`` tokens).

    Raises:
        ValueError: If ``k`` is below 1 or ``temperature`` is not positive.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # expand(1, -1) accepts (seq_len,) or (1, seq_len) and yields one beam.
    beams = start_sequence.clone().to(device).expand(1, -1)
    beam_scores = torch.zeros(1, device=device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False, dynamic_ncols=True, desc="beam"):
            step_log_probs = []
            for beam in beams:
                results = model(beam.unsqueeze(0), transpose=True).transpose(0, 1)
                step_log_probs.append(
                    F.log_softmax(results[:, -1, :] / temperature, dim=-1)
                )

            # (num_beams, vocab): cumulative score of every one-token extension.
            candidate_scores = beam_scores.unsqueeze(1) + torch.cat(step_log_probs, dim=0)
            vocab_size = candidate_scores.size(-1)

            # Global top-k over all extensions; topk returns them best-first.
            width = min(k, candidate_scores.numel())
            beam_scores, flat_indices = torch.topk(candidate_scores.flatten(), width)
            parents = flat_indices // vocab_size
            tokens = flat_indices % vocab_size

            beams = torch.cat((beams[parents], tokens.unsqueeze(1)), dim=1)

    return beams[0]
|
|
|
|
def evaluate(
    model,
    start_sequence,
    amt=10,
):
    """Greedily extend ``start_sequence`` by ``amt`` tokens.

    The sequence is moved to the module-level ``device``. Each step runs the
    model on the full sequence and appends, for every row, the argmax of the
    final position's logits.

    Args:
        model: Callable invoked as ``model(seq, transpose=True)``. Its output
            is transposed on dims 0/1 and indexed as ``[:, -1, :]``, so it
            presumably becomes ``(batch, seq, vocab)`` -- TODO confirm.
        start_sequence: Prompt token ids of shape ``(batch, seq_len)``.
        amt: Number of tokens to generate.

    Returns:
        The prompt with ``amt`` greedily chosen token ids appended along dim 1.
    """
    generated_sequence = start_sequence.clone().to(device)

    model.eval()
    with torch.no_grad():
        for _ in trange(amt, leave=False):
            results = model(generated_sequence, transpose=True).transpose(0, 1)
            # Only the last position matters: (batch, vocab) -> (batch, 1).
            next_token = results[:, -1, :].argmax(dim=-1).unsqueeze(1)
            generated_sequence = torch.cat((generated_sequence, next_token), dim=1)

    return generated_sequence
|
|
|
|
| def tester_exactly_like_trainingmanager_please_please_work(model, rawbatch): |
| labels = rawbatch[:, 1:].contiguous() |
| batch = rawbatch[:, :-1].contiguous() |
| results = model(batch, transpose=True) |
| results = results.transpose(0, 1) |
| print( |
| torch.sum( |
| torch.argmax(results.reshape(-1, results.size(-1)), dim=1) |
| == labels.reshape(-1) |
| ) |
| / len(labels.reshape(-1)) |
| ) |
| return torch.argmax(results.reshape(-1, results.size(-1)), dim=1), labels.reshape( |
| -1 |
| ) |
|
|
|
|
| def tester_exactly_like_trainingmanager_only_last_please_work(model, rawbatch): |
| labels = rawbatch[:, 1:].contiguous() |
| batch = rawbatch[:, :-1].contiguous() |
|
|
| batch = batch[-1].unsqueeze(0) |
| labels = labels[-1].unsqueeze(0) |
|
|
| results = model(batch, transpose=True) |
| results = results.transpose(0, 1) |
| print( |
| torch.sum( |
| torch.argmax(results.reshape(-1, results.size(-1)), dim=1) |
| == labels.reshape(-1) |
| ) |
| / len(labels.reshape(-1)) |
| ) |
| return torch.argmax(results.reshape(-1, results.size(-1)), dim=1), labels.reshape( |
| -1 |
| ) |
|
|
| |
| |
|
|
| |
| |
|
|
| return torch.argmax(results.reshape(-1, results.size(-1)), dim=1)[-1] |
|
|
|
|
def compute_entropy(logits):
    """Return the mean Shannon entropy (in nats) of softmax(logits) over the last dim.

    Uses ``log_softmax`` for numerical stability. Terms with zero probability
    (underflow or ``-inf`` logits) contribute 0, per the convention
    0 * log 0 = 0. Without this, ``p * log(p)`` would evaluate to NaN for them.

    Args:
        logits: Tensor of unnormalized scores; the last dim is the distribution.

    Returns:
        Python float: entropy averaged over all leading dimensions.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    plogp = (probs * log_probs).masked_fill(probs == 0, 0.0)
    entropy = -plogp.sum(dim=-1)
    return entropy.mean().item()
|
|
|
|
def main():
    """Load the latest checkpoint and print diagnostics plus a beam-search sample.

    Steps:
    1. Build the model with ``make_model()``.
    2. Load ``<EXPERIMENT_DIRECTORY>/ckpt/latest.pt`` onto ``device``.
    3. Report NaN parameters and gradients.
    4. Seed torch's RNG from a prompted string.
    5. Run both accuracy testers on the first training batch.
    6. Decode the first 100 input/label tokens of sample 0.
    7. Print the entropy of the last position.
    8. Print a 100-token beam-search continuation (beam width 3).
    """
    net = make_model()
    net.to(device)

    checkpoint_path = os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "latest.pt")
    print(checkpoint_path)
    # map_location lets a checkpoint saved on another device (e.g. MPS/CUDA)
    # load on the configured device.
    net.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )

    for name, param in net.named_parameters():
        if torch.isnan(param).any():
            print(f"NaN found in {name}")
    for name, param in net.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"NaN found in gradients of {name}")

    loader = get_dataloader(get_train_dataset())
    # Deterministic seed derived from a user-supplied string.
    torch.random.manual_seed(sum(ord(ch) for ch in input("seed? ")))

    # Only the first batch is inspected; an empty loader does nothing.
    data = next(iter(loader), None)
    if data is None:
        return
    batch, attn_mask = data  # attn_mask is not used below

    print(tester_exactly_like_trainingmanager_please_please_work(net, rawbatch=batch))
    print("pretty please")

    print(tester_exactly_like_trainingmanager_only_last_please_work(net, rawbatch=batch))
    print("please please please")

    # Shift-by-one split, then keep the first 100 positions of sample 0.
    labels = batch[:, 1:].contiguous()
    batch = batch[:, :-1].contiguous()
    batch = batch[0][:100]
    labels = labels[0][:100]
    print("Getting first 100 tokens for batch and labels")

    print(batch)
    print(dataset.manager.decode(batch))
    print("batch ^ labels v")
    print(dataset.manager.decode(labels))
    print("that's inp I guess ^^")

    with torch.no_grad():
        # NOTE(review): unlike the other model calls in this file, this one
        # omits transpose=True, yet indexes [:, -1, :] as if the output were
        # (batch, seq, vocab) -- confirm against the model's forward().
        logits = net(batch.unsqueeze(0))
        entropy = compute_entropy(logits[:, -1, :])

    print(f"Entropy of last token: {entropy:.4f}")

    print("USING BEAM")
    result = evaluate_beam(net, batch.unsqueeze(0), amt=100, k=3)

    result = dataset.manager.decode(result)
    batch_str = dataset.manager.decode(batch)

    # Show the prompt inside <data> tags, followed by only the generated text.
    result = f"<data>\n{batch_str}</data>\n{result[len(batch_str):]}"
    print(result)
|
|
|
|
if __name__ == "__main__":
    # Run the interactive demo only when executed as a script, not on import.
    main()
|
|