import math

import torch
from torch.optim import Optimizer
from torch.optim.adamw import adamw
from torch.optim.lr_scheduler import LambdaLR

try:
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
except ImportError:
    # DeepSpeed is optional; FusedAdam / DeepSpeedCPUAdam are only available
    # when it is installed.
    pass


def get_optimizer(cfg, params):
    if cfg.optim.type == 'adam':
        return torch.optim.Adam(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2)
        )
    elif cfg.optim.type == 'adamw':
        return AdamW(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2)
        )
    elif cfg.optim.type == 'fusedadam':
        # Requires the optional DeepSpeed import above.
        return FusedAdam(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2),
        )
    else:
        raise NotImplementedError('Optimizer not supported: %s' % cfg.optim.type)
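
# Usage sketch (assumes an attribute-style config such as OmegaConf/Hydra; the
# config class itself is an assumption, the field names mirror the reads above):
#
#   cfg.optim.type = 'adamw'
#   cfg.optim.lr = 1e-4
#   cfg.optim.weight_decay = 0.01
#   cfg.optim.beta1, cfg.optim.beta2 = 0.9, 0.999
#   optimizer = get_optimizer(cfg, model.parameters())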


class AdamW(torch.optim.AdamW):
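    """Thin subclass of ``torch.optim.AdamW``.

    The only intended deviation from the stock optimizer appears to be that the
    per-parameter ``step`` counters handed to the functional ``adamw`` kernel
    are kept on CPU (see ``step`` below), which sidesteps the
    "If capturable=False, state_steps should not be CUDA tensors" error raised
    by some PyTorch versions, e.g. after loading an optimizer state onto the
    GPU. This note is inferred from the code rather than stated upstream.
    """
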
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = (
                        torch.zeros((1,), dtype=torch.float, device=p.device)
                        if self.defaults['capturable'] else torch.tensor(0.)
                    )
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # Keep the step counter on CPU and write it back into the state,
                # so the in-place increment inside the functional `adamw` call is
                # not lost on a temporary copy. Forcing CPU steps also avoids the
                # "If capturable=False, state_steps should not be CUDA tensors"
                # error seen in some PyTorch versions.
                state['step'] = state['step'].cpu()
                state_steps.append(state['step'])

            adamw(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  max_exp_avg_sqs,
                  state_steps,
                  amsgrad=amsgrad,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  eps=group['eps'],
                  maximize=group['maximize'],
                  foreach=group['foreach'],
                  capturable=group['capturable'])

        return loss


def get_scheduler(cfg, optimizer):
    if cfg.optim.scheduler is None:
        return BlackHole()
    elif cfg.optim.scheduler == 'plateau':
        return (
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=cfg.optim.mode,
                factor=cfg.optim.factor,
                patience=cfg.optim.patience,
                min_lr=cfg.optim.min_lr,
            ),
            {'monitor': "val/loss", 'interval': 'epoch'}
        )
    elif cfg.optim.scheduler == 'noam':
        return (
            NoamScheduler(
                optimizer,
                lr=cfg.optim.lr,
                warmup_steps=cfg.optim.warmup_steps,
                model_size=cfg.optim.model_size,
                warmup_init_lr=cfg.optim.get('warmup_init_lr')
            ),
            {'frequency': 1, 'interval': 'step'}
        )
    elif cfg.optim.scheduler == 'polynomial':
        return (
            PolyNomialLRScheduler(
                optimizer,
                total_steps=cfg.training.max_steps,
                warmup_steps=cfg.training.warmup_steps,
                lr=cfg.optim.lr,
                lr_end=cfg.optim.lr_end,
                warmup_init_lr=cfg.optim.warmup_init_lr,
                power=cfg.optim.power
            ),
            {'frequency': 1, 'interval': 'step'}
        )
    elif cfg.optim.scheduler == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.optim.milestones,
            gamma=cfg.optim.gamma,
        )
    elif cfg.optim.scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.optim.gamma,
        )
    elif cfg.optim.scheduler == 'progen_ft':
        # Cosine decay of the learning rate to lr / 5, as used for ProGen fine-tuning.
        sched = CosineToFrac(
            optimizer=optimizer,
            total_steps=cfg.training.max_steps,
            final_frac=0.2,
        )
        return (sched, {'frequency': 1, 'interval': 'step'})
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.optim.scheduler)
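
# Usage sketch: the (scheduler, options) tuples returned above look like they
# are intended for a PyTorch Lightning-style `configure_optimizers`; Lightning
# is not imported here, so this wiring is an assumption:
#
#   def configure_optimizers(self):
#       optimizer = get_optimizer(self.cfg, self.parameters())
#       sched = get_scheduler(self.cfg, optimizer)
#       if isinstance(sched, tuple):
#           scheduler, opts = sched
#           return {'optimizer': optimizer,
#                   'lr_scheduler': {'scheduler': scheduler, **opts}}
#       return optimizer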


class BlackHole(object):
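    """No-op placeholder returned when no scheduler is configured.

    Attribute access and calls both return the object itself and attribute
    assignment is ignored, so e.g. ``BlackHole().step()`` silently does nothing.
    """
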
    def __setattr__(self, name, value):
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        return self


def polynomial_lr_schedule(step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power):
    if step < warmup_steps:
        return warmup_init_lr + (lr - warmup_init_lr) * step / warmup_steps
    elif step > total_steps:
        return lr_end
    else:
        return lr_end + (lr - lr_end) * (1 - (step - warmup_steps) / (total_steps - warmup_steps)) ** power
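
# Worked example (illustrative values): with warmup_init_lr=1e-7, lr=4e-5,
# lr_end=1e-5, power=1.0, warmup_steps=100, total_steps=1000:
#   step 0    -> 1e-7                       (start of linear warmup)
#   step 100  -> 4e-5                       (peak lr at the end of warmup)
#   step 550  -> 1e-5 + 3e-5 * 0.5 = 2.5e-5 (halfway through the linear decay)
#   step 1000 and beyond -> 1e-5            (lr_end)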


class PolyNomialLRScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int = 1000,
        warmup_steps: int = 0,
        lr: float = 0.00004,
        lr_end: float = 1e-5,
        warmup_init_lr: float = 1e-07,
        power: float = 1.0,
    ) -> None:
        self.warmup_init_lr = warmup_init_lr
        self.warmup_steps = warmup_steps

        def lr_lambda(step):
            return polynomial_lr_schedule(
                step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power
            ) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)
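
# Usage sketch: because the lambda divides by `lr`, the wrapped optimizer is
# expected to have been created with that same base learning rate (as
# get_scheduler does above); otherwise the schedule is rescaled. Values below
# are illustrative only:
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)
#   scheduler = PolyNomialLRScheduler(optimizer, total_steps=1000, warmup_steps=100,
#                                     lr=4e-5, lr_end=1e-5, power=1.0)
#   for step in range(1000):
#       ...  # forward / backward / optimizer.step()
#       scheduler.step()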


def cosine_frac_scheduler(step, total_steps, final_frac):
    s = min(max(step, 0), total_steps)
    cos = 0.5 * (1.0 + math.cos(math.pi * s / total_steps))
    return final_frac + (1.0 - final_frac) * cos
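
# Worked example: with total_steps=1000 and final_frac=0.2 the multiplier goes
#   step 0    -> 1.0
#   step 500  -> 0.2 + 0.8 * 0.5 = 0.6
#   step 1000 (and beyond) -> 0.2
# i.e. the learning rate is cosine-annealed from lr down to lr / 5.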


class CosineToFrac(LambdaLR):
    """
    Cosine decay of the LR multiplier from 1.0 -> final_frac over total_steps (no warmup).
    For ProGen fine-tuning, final_frac=0.2 implements decay to lr/5.
    """
    def __init__(self, optimizer, total_steps, final_frac=0.2):
        self.total_steps = max(int(total_steps), 1)
        self.final_frac = float(final_frac)

        def lr_lambda(step):
            return cosine_frac_scheduler(
                step=step,
                total_steps=self.total_steps,
                final_frac=self.final_frac
            )

        super().__init__(optimizer, lr_lambda=lr_lambda)


def inverse_sqrt_lr_schedule(step, warmup_steps, warmup_init_lr, lr_step, decay_step):
    if step == 0:
        step = 1
    if step < warmup_steps:
        return warmup_init_lr + lr_step * step
    else:
        return decay_step * step ** -0.5
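
# Worked example: with lr=5e-4 and warmup_steps=4000 (so decay_step = lr * sqrt(4000)),
# the schedule rises linearly from warmup_init_lr to lr during warmup, reaches lr at
# step 4000, then decays as 1/sqrt(step): at step 16000 it is lr * sqrt(4000/16000) = lr / 2.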


class InverseSqrtLRScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int = 0,
        lr: float = 5e-04,
        warmup_init_lr: float = 1e-07,
    ) -> None:
        self.warmup_init_lr = warmup_init_lr
        # Guard against warmup_steps == 0 (the default), which would otherwise
        # divide by zero when computing the per-step warmup increment.
        self.warmup_steps = max(int(warmup_steps), 1)
        self.lr_step = (lr - warmup_init_lr) / self.warmup_steps
        self.decay_step = lr * self.warmup_steps ** 0.5

        def lr_lambda(step):
            return inverse_sqrt_lr_schedule(
                step, self.warmup_steps, warmup_init_lr, self.lr_step, self.decay_step
            ) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)


def noam_lr_schedule(step, warmup_steps, factor, model_size):
    if step == 0:
        step = 1
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)))
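
# Worked example: with factor=2, model_size=128 and warmup_steps=4000 the schedule
# peaks at step == warmup_steps with
#   2 * 128 ** -0.5 * 4000 ** -0.5 ≈ 2.8e-3
# rising linearly before the peak and decaying as 1/sqrt(step) afterwards.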


class NoamScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        lr,
        warmup_init_lr=None,
        model_size: int = 128,
        warmup_steps: int = 0,
        factor: int = 2,
    ) -> None:
        # `warmup_init_lr` is accepted for interface compatibility with the other
        # schedulers but is not used by the Noam formula.

        def lr_lambda(step):
            # Divide by `lr` so that LambdaLR (which multiplies the optimizer's
            # base lr back in) yields the raw Noam schedule.
            return noam_lr_schedule(step, warmup_steps, factor, model_size) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)