| import os |
| import torch |
| from logger import log_data, init_logger, log_img |
| import torch.nn as nn |
| from tqdm import tqdm, trange |
| from torch.profiler import profile, record_function, ProfilerActivity |
| import gc |
| import numpy as np |
| from eval import evaluate_topk |
| from dataset import dataset |
| from Levenshtein import ratio |
| from enum import Enum |
| import signal |
| import sys |
|
|
| device = "mps" if torch.backends.mps.is_available() else "cpu" |
|
|
|
|
| from collections import defaultdict |
|
|
|
|
| class ValueTracker: |
| def __init__(self): |
| self.data = {} |
|
|
| def add(self, label, value): |
| if label not in self.data: |
| self.data[label] = [] |
| self.data[label].append(value) |
|
|
| def average(self, label): |
| values = self.data[label] |
| if values: |
| return sum(values) / len(values) |
| else: |
| return 0.0 |
|
|
| def reset(self, label=None): |
| if label is not None: |
| if label in self.data: |
| self.data[label] = [] |
| else: |
| self.data = {} |
|
|
| def get_values(self, label): |
| return self.data[label] |
|
|
| def summary(self): |
| for label in self.data: |
| avg = self.average(label) |
| print(f"{label} - Average: {avg:.4f}") |
|
|
|
|
| class TrainingManager: |
| def __init__( |
| self, |
| net: nn.Module, |
| dir: str, |
| dataloader, |
| device=device, |
| trainstep_checkin_interval=100, |
| epochs=100, |
| val_dataloader=None, |
| ): |
|
|
| learning_rate = 0.001 |
|
|
| self.clip = 1.0 |
|
|
| self.trainstep_checkin_interval = trainstep_checkin_interval |
| self.epochs = epochs |
|
|
| self.dataloader = dataloader |
| self.val_dataloader = val_dataloader |
|
|
| self.net = net |
| self.net.to(device) |
| self.device = device |
|
|
| self.dir = dir |
|
|
| self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1) |
| self.optimizer = torch.optim.AdamW( |
| self.net.parameters(), lr=learning_rate |
| ) |
|
|
| |
| |
| self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( |
| optimizer=self.optimizer, factor=0.9, patience=10 |
| ) |
|
|
| self.tracker = ValueTracker() |
|
|
| self.resume_epoch, self.resume_step = self.get_resume() |
| if self.resume_epoch >= self.epochs - 1: |
| pass |
| elif self.resume_epoch != 0 or self.resume_step != 0: |
| self.resume() |
| else: |
| if os.path.exists(self.dir) and any( |
| os.path.isfile(os.path.join(self.dir, item)) |
| for item in os.listdir(self.dir) |
| ): |
| raise ValueError(f"The directory '{self.dir}' contains files!") |
|
|
| os.makedirs(self.dir, exist_ok=True) |
| os.makedirs(os.path.join(self.dir, "ckpt"), exist_ok=True) |
|
|
| print(f"{self.get_param_count()} parameters.") |
| |
| |
| signal.signal(signal.SIGINT, self._signal_handler) |
| self._interrupted = False |
|
|
| def _signal_handler(self, signum, frame): |
| """Handle keyboard interrupt gracefully""" |
| print("\nKeyboard interrupt received. Saving checkpoint...") |
| self._interrupted = True |
|
|
| def _save_on_interrupt(self, epoch, step): |
| """Save checkpoint and resume info on interrupt""" |
| try: |
| self._save("latest.pt") |
| self.write_resume(epoch, step) |
| print(f"Checkpoint saved at epoch {epoch}, step {step}") |
| except Exception as e: |
| print(f"Failed to save checkpoint: {e}") |
| finally: |
| print("Exiting...") |
| sys.exit(0) |
|
|
| def hasnan(self): |
| for _, param in self.net.named_parameters(): |
| if torch.isnan(param).any(): |
| return True |
| for _, param in self.net.named_parameters(): |
| if param.grad is not None and torch.isnan(param.grad).any(): |
| return True |
|
|
| return False |
|
|
| def _save(self, name="latest.pt"): |
| with open(os.path.join(self.dir, "ckpt", name), "wb+") as f: |
| torch.save(self.net.state_dict(), f) |
|
|
| def _load(self, name="latest.pt"): |
| self.net.load_state_dict( |
| torch.load(os.path.join(self.dir, "ckpt", name), weights_only=True) |
| ) |
|
|
| def write_resume(self, epoch, step=0): |
| with open(os.path.join(self.dir, "ckpt", "resume.txt"), "w+") as f: |
| f.write(f"{epoch},{step}") |
|
|
| def get_resume(self): |
| try: |
| with open(os.path.join(self.dir, "ckpt", "resume.txt"), "r") as f: |
| content = f.read().strip() |
| if ',' in content: |
| epoch, step = content.split(',') |
| return int(epoch), int(step) |
| else: |
| |
| return int(content), 0 |
| except (FileNotFoundError, ValueError): |
| return 0, 0 |
|
|
| def write_best_val_loss(self, loss): |
| with open(os.path.join(self.dir, "ckpt", "best_val_loss.txt"), "w+") as f: |
| f.write(f"{loss:.6f}") |
|
|
| def get_best_val_loss(self): |
| try: |
| with open(os.path.join(self.dir, "ckpt", "best_val_loss.txt"), "r") as f: |
| return float(f.read()) |
| except (FileNotFoundError, ValueError): |
| return float("inf") |
|
|
| def resume(self): |
| self._load("latest.pt") |
|
|
| def save(self, loss): |
| self._save("latest.pt") |
|
|
| best_val_loss = self.get_best_val_loss() |
| if loss < best_val_loss: |
| best_val_loss = loss |
| self._save("best.pt") |
| self.write_best_val_loss(best_val_loss) |
|
|
| |
|
|
| def on_trainloop_checkin(self, epoch, step, dataloader_len): |
| if self.hasnan(): |
| |
| print("RESUMING") |
| self.resume() |
|
|
| self._save("latest.pt") |
| self.write_resume(epoch, step + 1) |
|
|
| log_data( |
| {"Loss/Trainstep": self.tracker.average("Loss/trainstep")}, |
| epoch * dataloader_len + step, |
| ) |
| log_data( |
| {"Acc/Trainstep": self.tracker.average("Acc/trainstep")}, |
| epoch * dataloader_len + step, |
| ) |
| log_data( |
| {"TopKAcc/Trainstep": self.tracker.average("TopKAcc/trainstep")}, |
| epoch * dataloader_len + step, |
| ) |
|
|
| self.tracker.reset("Loss/trainstep") |
| self.tracker.reset("Acc/trainstep") |
| self.tracker.reset("TopKAcc/trainstep") |
|
|
| def on_epoch_checkin(self, epoch): |
| if self.hasnan(): |
| |
| self.resume() |
|
|
| val_loss = float("inf") |
| try: |
| val_loss = self.tracker.average("Loss/val/epoch") |
| except KeyError: |
| pass |
|
|
| self.save( |
| val_loss if val_loss < float("inf") else self.tracker.average("Loss/epoch") |
| ) |
|
|
| log_data( |
| { |
| "Loss/Epoch": self.tracker.average("Loss/epoch"), |
| "Loss/Val/Epoch": val_loss, |
| "Perplexity/Val/Epoch": float(np.exp(val_loss)), |
| "TopKAcc/Epoch": self.tracker.average("TopKAcc/epoch"), |
| }, |
| epoch, |
| ) |
|
|
| self.tracker.reset("Acc/epoch") |
| self.tracker.reset("Loss/epoch") |
| self.tracker.reset("Loss/val/epoch") |
| self.tracker.reset("TopKAcc/epoch") |
| self.tracker.reset("Perplexity/val/epoch") |
|
|
| self.write_resume(epoch + 1, 0) |
|
|
| def eval_model(self, data, compute_metrics=True): |
| if type(data) == tuple or type(data) == list: |
| data = tuple(d.to(self.device) for d in data) |
| batch, attn_mask = data |
| else: |
| data = data.to(self.device) |
| batch = data |
| attn_mask = None |
|
|
| del attn_mask |
|
|
| labels = batch[:, 1:].contiguous() |
| batch = batch[:, :-1].contiguous() |
|
|
| |
| results = self.net(batch, transpose=True) |
| results = results.transpose(0, 1) |
|
|
| |
| loss = self.criterion(results.reshape(-1, results.size(-1)), labels.reshape(-1)) |
|
|
| if not compute_metrics: |
| return loss, None, None |
| |
| |
| preds = results.reshape(-1, results.size(-1)).argmax(dim=1) |
| labels_flat = labels.reshape(-1) |
| acc = (preds == labels_flat).float().mean() |
|
|
| |
| top_k = 5 |
| top_k_preds = results.reshape(-1, results.size(-1)).topk(top_k, dim=1).indices |
| top_k_acc = (top_k_preds == labels_flat.unsqueeze(1)).any(dim=1).float().mean().item() |
|
|
| return loss, acc, top_k_acc |
|
|
| def run_generation(self, data): |
| batch, attn_mask = data |
| start_sequence = batch[:, :-1].contiguous()[0][:100].unsqueeze(0) |
| result = evaluate_topk( |
| self.net, start_sequence, amt=100, k=10, temperature=0.8, device=device |
| ) |
|
|
| result = dataset.manager.decode(result[0]) |
| batch_str = dataset.manager.decode(start_sequence[0]) |
|
|
| result = f"<data>{batch_str}</data>{result[len(batch_str):]}" |
| |
|
|
| with open(os.path.join(self.dir, "ckpt", "generated.txt"), "a+") as f: |
| f.write(f"K=10,T=0.8: {result}\n") |
|
|
| def epoch_gen(self, loader): |
| if loader is not None: |
| for data in loader: |
| self.run_generation(data) |
| break |
|
|
| def trainstep(self, data): |
| self.optimizer.zero_grad() |
|
|
| loss, acc, topk_acc = self.eval_model(data) |
|
|
| self.tracker.add("Loss/trainstep", loss.item()) |
| self.tracker.add("Loss/epoch", loss.item()) |
|
|
| self.tracker.add("Acc/trainstep", acc.item()) |
| self.tracker.add("TopKAcc/trainstep", topk_acc) |
| self.tracker.add("TopKAcc/epoch", topk_acc) |
|
|
| loss.backward() |
| self.optimizer.step() |
|
|
| return loss.detach(), acc.detach() |
|
|
| @torch.no_grad() |
| def valstep(self, data): |
| loss, acc, topk_acc = self.eval_model(data) |
|
|
| self.tracker.add("Loss/valstep", loss.item()) |
| self.tracker.add("Loss/val/epoch", loss.item()) |
|
|
| self.tracker.add("Perplexity/val/epoch", float(np.exp(loss.item()))) |
|
|
| self.tracker.add("TopKAcc/valstep", topk_acc) |
| self.tracker.add("TopKAcc/val/epoch", topk_acc) |
|
|
| return loss.detach(), acc.detach() |
|
|
| def val_loop(self, val_loader): |
| if val_loader is not None: |
| for step, data in enumerate( |
| test_tqdm := tqdm( |
| val_loader, leave=False, dynamic_ncols=True, desc=f"valloop" |
| ) |
| ): |
| self.valstep(data) |
| avg_val_loss = self.tracker.average("Loss/val/epoch") |
| test_tqdm.set_postfix({"Val Loss": f"{avg_val_loss:.3f}"}) |
|
|
| def train_loop(self, dataloader, epoch): |
| start_step = self.resume_step if epoch == self.resume_epoch else 0 |
| |
| for step, data in enumerate( |
| train_tqdm := tqdm( |
| dataloader, leave=False, dynamic_ncols=True, desc=f"trainloop" |
| ) |
| ): |
| |
| if self._interrupted: |
| self._save_on_interrupt(epoch, step) |
| raise KeyboardInterrupt("Training interrupted by user") |
| |
| |
| if step < start_step: |
| continue |
| |
| self.trainstep(data) |
|
|
| avg_train_loss = self.tracker.average("Loss/trainstep") |
| train_tqdm.set_postfix({"Train Loss": f"{avg_train_loss:.3f}"}) |
|
|
| if ( |
| step % self.trainstep_checkin_interval |
| == self.trainstep_checkin_interval - 1 |
| ): |
| |
| self.on_trainloop_checkin(epoch, step, len(dataloader)) |
| |
|
|
| def epoch(self, epoch: int, dataloader, val_loader=None): |
| if self._interrupted: |
| return |
| |
| self.net.train() |
| self.train_loop(dataloader, epoch) |
| |
| if self._interrupted: |
| return |
| |
| tqdm.write(self.get_memory_stats(self.net, dataloader.dataset, sep=" / ")) |
| self.net.eval() |
| self.val_loop(val_loader) |
|
|
| if self._interrupted: |
| return |
| |
| self.epoch_gen(val_loader) |
| self.on_epoch_checkin(epoch) |
|
|
| def train(self, epochs=None, dataloader=None): |
|
|
| if epochs is not None: |
| self.epochs = epochs |
|
|
| if dataloader is not None: |
| self.dataloader = dataloader |
|
|
| try: |
| for e in trange( |
| self.resume_epoch, self.epochs, dynamic_ncols=True, unit_scale=True, unit_divisor=60 |
| ): |
| if self._interrupted: |
| break |
| |
| self.epoch(e, self.dataloader, self.val_dataloader) |
|
|
| except KeyboardInterrupt: |
| print("\nTraining interrupted. Checkpoint saved.") |
| finally: |
| print("Training session ended.") |
| gc.collect() |
| os.system( |
| """osascript -e 'display notification "Training complete" with title "Training Complete"'""" |
| ) |
|
|
| @staticmethod |
| def get_curriculum_enum(): |
| return Enum( |
| "Curriculum", |
| [ |
| ("NOOP", 1), |
| ("CURRICULUM", 2), |
| ("ANTICURRICULUM", 3), |
| ("SEQUENTIAL", 4), |
| ("HYBRID", 5), |
| ], |
| ) |
|
|
| def train_curriculum( |
| self, epochs=None, dataloader=None, curriculum_type=None, loss_based=False |
| ): |
|
|
| print(f"Training curriculum: {curriculum_type} loss_based: {loss_based}") |
|
|
| Curriculum = self.get_curriculum_enum() |
|
|
| if curriculum_type is None: |
| curriculum_type = Curriculum.NOOP |
|
|
| if epochs is not None: |
| self.epochs = epochs |
|
|
| if dataloader is not None: |
| self.dataloader = dataloader |
|
|
| sorted_indices = sorted( |
| range(len(self.dataloader.dataset)), |
| key=lambda i: self.dataloader.dataset[i][1], |
| reverse=(curriculum_type.value == Curriculum.ANTICURRICULUM.value), |
| ) |
|
|
| |
| standard_schedule = [ |
| min(1.0, ((i + 2) - (i % 2)) / self.epochs) for i in range(self.epochs) |
| ] |
| hybrid_schedule = [ |
| min(1.0, (i + 2) / self.epochs) for i in range(self.epochs) |
| ] |
| step_size = 1 / (self.epochs / 2) |
|
|
| try: |
| for e in trange( |
| self.resume_epoch, self.epochs, dynamic_ncols=True, unit_scale=True, unit_divisor=60 |
| ): |
|
|
| if loss_based: |
| sorted_indices = self.get_loss_based_indices( |
| self.dataloader, |
| anti=(curriculum_type.value == Curriculum.ANTICURRICULUM.value), |
| ) |
|
|
| subset_indices = None |
| if curriculum_type.value == Curriculum.NOOP.value: |
| print("No curriculum") |
| subset_indices = sorted_indices |
| elif curriculum_type.value == Curriculum.SEQUENTIAL.value: |
| print("Sequential curriculum") |
| subset_indices = sorted_indices[ |
| int( |
| max(len(sorted_indices) * (standard_schedule[e] - step_size), 0) |
| ) : int(len(sorted_indices) * standard_schedule[e]) |
| ] |
| elif curriculum_type.value == Curriculum.HYBRID.value: |
| print("Hybrid curriculum") |
| subset_indices = sorted_indices[ |
| int( |
| max(len(sorted_indices) * (hybrid_schedule[e] - step_size), 0) |
| ) : int(len(sorted_indices) * hybrid_schedule[e]) |
| ] |
| elif curriculum_type.value == Curriculum.CURRICULUM.value: |
| print("Curriculum") |
| subset_indices = sorted_indices[ |
| : int(len(sorted_indices) * standard_schedule[e]) |
| ] |
| elif curriculum_type.value == Curriculum.ANTICURRICULUM.value: |
| print("Anti curriculum") |
| subset_indices = sorted_indices[ |
| : int(len(sorted_indices) * standard_schedule[e]) |
| ] |
| else: |
| raise ValueError(f"Unknown curriculum type: {curriculum_type}") |
|
|
| subset = torch.utils.data.Subset(self.dataloader.dataset, subset_indices) |
| cur_dataloader = torch.utils.data.DataLoader( |
| subset, batch_size=self.dataloader.batch_size, shuffle=True |
| ) |
|
|
| self.epoch(e, cur_dataloader, self.val_dataloader) |
|
|
| except KeyboardInterrupt: |
| print("\nCurriculum training interrupted. Checkpoint saved.") |
| finally: |
| print("Curriculum training session ended.") |
| gc.collect() |
| os.system( |
| """osascript -e 'display notification "Training complete" with title "Training Complete"'""" |
| ) |
|
|
| print("All done!") |
| gc.collect() |
| os.system( |
| """osascript -e 'display notification "Training complete" with title "Training Complete"'""" |
| ) |
|
|
| def get_loss_based_indices(self, dataloader, anti=False): |
| losses = [] |
| |
| temp_dataloader = torch.utils.data.DataLoader( |
| dataloader.dataset, |
| batch_size=dataloader.batch_size, |
| shuffle=False, |
| num_workers=( |
| dataloader.num_workers if hasattr(dataloader, "num_workers") else 0 |
| ), |
| ) |
|
|
| with torch.no_grad(): |
| for batch, _ in tqdm( |
| temp_dataloader, |
| dynamic_ncols=True, |
| leave=False, |
| desc="Loss-based sorting", |
| ): |
| loss, _, _ = self.eval_model(batch, compute_metrics=False) |
| |
| |
| if isinstance(loss, torch.Tensor) and loss.dim() == 0: |
| losses.extend([loss.item()] * batch.size(0)) |
| else: |
| |
| losses.extend(loss.detach().cpu().tolist()) |
|
|
| sorted_indices = sorted( |
| range(len(dataloader.dataset)), key=lambda i: losses[i], reverse=anti |
| ) |
| return sorted_indices |
|
|
| def nan_debug(self): |
| torch.autograd.set_detect_anomaly(True) |
|
|
| def forward_hook(module, input, output): |
| if isinstance(output, tuple): |
| return |
| if torch.isnan(output).any() or torch.isinf(output).any(): |
| print(f"NaNs/Infs detected in {module}") |
|
|
| for module in self.net.modules(): |
| module.register_forward_hook(forward_hook) |
| self.val_loop(self.val_dataloader) |
|
|
| def get_param_count(self): |
| return sum(p.numel() for p in self.net.parameters()) |
|
|
| def profile_trainstep(self): |
|
|
| self.net.train() |
| data = next(iter(self.dataloader)) |
|
|
| |
| with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: |
| with record_function("train_step"): |
| self.trainstep(data) |
|
|
| print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) |
|
|
| @staticmethod |
| def get_memory_stats(net, trainset, sep="\n"): |
| result = "" |
| import datetime |
| import time |
| result += f"Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + sep |
| import psutil |
| if torch.backends.mps.is_available(): |
| result += f"MPS: {torch.mps.current_allocated_memory()/1e9:.2f} GB" + sep |
| result += f"RAM: {psutil.virtual_memory().percent}% used" + sep |
| |
| |
| chunks = getattr(trainset, 'chunks', getattr(trainset.dataset, 'chunks', None)) |
| |
| if chunks is not None: |
| result += f"data: {sum(p.numel() * p.element_size() for p in [chunks]) / 1e9:.2f} GB" + sep |
| |
| |
| model_size = sum(p.numel() * p.element_size() for p in net.parameters()) / 1e9 |
| result += f"Params: {model_size:.2f} GB" + sep |
| |
| |
| optimizer_size = model_size * 2 |
| result += f"Optim (est): {optimizer_size:.2f} GB" + sep |
|
|
| return result |
|
|