| import sys |
| import torch |
| import torch.nn as nn |
|
|
|
|
| def _print(s): |
| print(s) |
| sys.stdout.flush() |
|
|
|
|
def get_latents(model, tokenizer, sequence, device):
    """Encode *sequence* and return the model's final hidden-state matrix.

    Args:
        model: forward pass must expose ``outputs.hidden_states``
            (NOTE(review): presumably the model was loaded with
            ``output_hidden_states=True`` — confirm at the call site).
        tokenizer: HF-style callable returning a batch movable via ``.to``.
        sequence: raw input sequence/text to tokenize.
        device: device the encoded inputs are moved to.

    Returns:
        Tensor of shape (seq_len, hidden_dim): the last layer's hidden
        states with the batch dimension squeezed away.
    """
    encoded = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        last_hidden = model(**encoded).hidden_states[-1]
    return last_hidden.squeeze(0)
|
|
|
|
|
|
| |
def freeze_model(model: nn.Module):
    """Disable gradient tracking on every parameter of *model*."""
    # requires_grad_(False) recursively flips requires_grad on all parameters,
    # equivalent to looping over model.parameters() by hand.
    model.requires_grad_(False)
|
|
|
|
|
|
| |
def apply_gptj_freezing(model, N_layers: int):
    """Freeze the whole model, then unfreeze the attention modules of the
    last *N_layers* transformer blocks of a GPT-J-style model.

    Args:
        model: model exposing ``model.transformer.h`` (sequence of blocks,
            each with an ``attn`` submodule).
        N_layers (int): number of trailing blocks whose attention
            parameters are made trainable again.

    Raises:
        ValueError: if N_layers is negative or exceeds the block count.
        AssertionError: if the post-freeze verification fails.
    """
    model_layers = len(model.transformer.h)
    # Validate up front: an out-of-range N_layers would otherwise surface
    # only as a confusing assertion failure in the verification pass below.
    if not 0 <= N_layers <= model_layers:
        raise ValueError(f"N_layers must be in [0, {model_layers}], got {N_layers}")

    def unfreeze_n_layers(model, N_layers):
        # Re-enable gradients on the attention stack of the last N blocks.
        # .parameters() is already recursive, so iterating .modules() first
        # (as before) only revisited each parameter redundantly.
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:
                for param in h.attn.parameters():
                    param.requires_grad = True

    def check_frozen_model(model, N_layers: int):
        """
        Verify that only the last N_layers of model.transformer.h are unfrozen.
        Source: https://github.com/enijkamp/progen2/blob/main/progen/modeling_progen.py
        """
        frozen_layers = 0
        unfrozen_layers = 0
        for i, h in enumerate(model.transformer.h):
            # A block counts as unfrozen if ANY of its parameters is trainable.
            trainable = any(param.requires_grad for param in h.parameters())
            if i >= model_layers - N_layers:
                if trainable:
                    unfrozen_layers += 1
                else:
                    print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
            else:
                if trainable:
                    print(f"Layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert frozen_layers == model_layers - N_layers and unfrozen_layers == N_layers, \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

        print(f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}")

    freeze_model(model)
    unfreeze_n_layers(model, N_layers)
    check_frozen_model(model, N_layers)
|
|
|
|
|
|
|
|
|
|
| |
def apply_rdm_freezing(model: nn.Module, N_layers: int, model_type: str):
    """
    Freeze all layers except the self-attention key/query/value projections
    of the last N_layers encoder layers, for esm-like architectures.

    Args:
        model (nn.Module): model to freeze
        N_layers (int): num encoder layers to unfreeze
        model_type (str): one of {"esm", "evoflow", "dplm"}

    Raises:
        ValueError: for an unknown model_type, or if N_layers is negative
            or exceeds the encoder layer count.
        AssertionError: if the post-freeze verification fails.
    """
    # The encoder stack lives at a different attribute path per wrapper type.
    if model_type == "dplm":
        encoder_layers = model.net.esm.encoder.layer
    elif model_type in ("esm", "evoflow"):
        encoder_layers = model.esm.encoder.layer
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    model_layers = len(encoder_layers)
    # Validate up front: an out-of-range N_layers would otherwise surface
    # only as a confusing assertion failure in the verification pass below.
    if not 0 <= N_layers <= model_layers:
        raise ValueError(f"N_layers must be in [0, {model_layers}], got {N_layers}")

    def _qkv_modules(layer):
        # The three projection submodules toggled per layer; factored out
        # to avoid the previous key/query/value copy-paste triplication.
        attn = layer.attention.self
        return (attn.key, attn.query, attn.value)

    def unfreeze_n_layers(layers, N_layers: int):
        # .parameters() is already recursive, so the old .modules() loop
        # only revisited each parameter redundantly.
        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                for proj in _qkv_modules(layer):
                    for param in proj.parameters():
                        param.requires_grad = True

    def check_model(layers, N_layers: int):
        frozen_layers = 0
        unfrozen_layers = 0

        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                # Unfrozen layers must have at least one trainable Q/K/V param.
                trainable = any(
                    param.requires_grad
                    for proj in _qkv_modules(layer)
                    for param in proj.parameters()
                )
                if trainable:
                    unfrozen_layers += 1
                else:
                    print(f"layer {i} has all parameters frozen, but it should be unfrozen.")
            else:
                # Frozen layers must have NO trainable parameter anywhere.
                if any(param.requires_grad for param in layer.parameters()):
                    print(f"layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert (frozen_layers == model_layers - N_layers) and (unfrozen_layers == N_layers), \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

    freeze_model(model)
    unfreeze_n_layers(encoder_layers, N_layers)
    check_model(encoder_layers, N_layers)
|
|