ESMplusplus_small / modeling_esm_plusplus.py
lhallee's picture
Upload modeling_esm_plusplus.py with huggingface_hub
42efb78 verified
### Embedding Mixin + Pooler
import os
import sqlite3
import networkx as nx
import numpy as np
import torch
from tqdm.auto import tqdm
from typing import Callable, List, Optional
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from transformers import PreTrainedTokenizerBase
class Pooler:
    """Reduce per-token embeddings ``(b, L, d)`` to per-sequence vectors ``(b, d)``.

    One or more pooling strategies are selected by name at construction time;
    calling the instance applies each of them and concatenates the results
    along the last dimension.

    Supported names: ``mean``, ``max``, ``norm``, ``median``, ``std``, ``var``,
    ``cls`` and ``parti`` (PageRank-weighted average, needs attention maps).
    """

    def __init__(self, pooling_types: List[str]):
        self.pooling_types = pooling_types
        # Name -> bound method dispatch table used by __call__.
        self.pooling_options = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        """Collapse dim 1 of ``attentions`` with an element-wise maximum."""
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
        """Run PageRank on the attention matrix converted to a directed graph.

        Raises:
            Exception: if the graph's node count differs from the matrix size
                or the graph has no edges.

        Returns:
            dict mapping node index -> PageRank score.
        """
        # NOTE: ``prune_type`` is accepted for interface compatibility but unused.
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception(f"You don't seem to have any attention edges left in the graph.")
        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix):
        """Build a weighted ``networkx.DiGraph`` whose edges are the matrix entries."""
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
        """Normalise PageRank scores to weights that sum to 1.

        Entries whose position is masked out (``attention_mask[k] == 0``) are
        removed from ``dict_importance`` in place before normalising.
        """
        if attention_mask is not None:
            for k in list(dict_importance.keys()):
                if attention_mask[k] == 0:
                    del dict_importance[k]
        total = sum(dict_importance.values())
        return np.array([v / total for v in dict_importance.values()])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        """PageRank-weighted average of the token embeddings.

        Args:
            emb: token embeddings, ``(b, L, d)``.
            attentions: attention maps; dim 1 is max-reduced to give one
                ``(L, L)`` matrix per sequence.
            attention_mask: optional ``(b, L)`` mask; ``None`` means every
                position is valid.

        Raises:
            ValueError: if ``attentions`` is ``None``.
        """
        if attentions is None:
            raise ValueError("'parti' pooling requires attention maps; pass attentions=... to the pooler.")
        if attention_mask is None:
            # No mask supplied: treat every token as valid instead of failing in zip().
            attention_mask = torch.ones(emb.shape[:2], dtype=torch.long, device=emb.device)
        # networkx/numpy need a host array; detach/cpu/float makes this work for
        # GPU tensors, tensors that require grad, and reduced-precision dtypes.
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).detach().cpu().float().numpy()
        emb_pooled = []
        for e, a, mask in zip(emb, maxed_attentions, attention_mask):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            # Select exactly the positions the weights were kept for.
            valid_tokens = e[mask.bool()].detach().cpu().float().numpy()
            emb_pooled.append(np.average(valid_tokens, weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled))
        return pooled

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Mean over the sequence dimension, ignoring masked positions."""
        if attention_mask is None:
            return emb.mean(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Element-wise max over the sequence dimension, ignoring masked positions."""
        if attention_mask is None:
            return emb.max(dim=1).values
        keep = attention_mask.unsqueeze(-1).bool()
        # Padding must not win the max: fill it with -inf rather than 0, which
        # would beat every valid value whenever those are all negative.
        return emb.masked_fill(~keep, float('-inf')).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """L2 norm over the sequence dimension (zeroed padding adds nothing)."""
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Element-wise median over the sequence dimension, ignoring masked positions."""
        if attention_mask is None:
            return emb.median(dim=1).values
        keep = attention_mask.bool()
        # Sequences have different valid lengths, so take the median per
        # sequence over its valid tokens only (zeroed padding would skew it).
        return torch.stack([tokens[valid].median(dim=0).values for tokens, valid in zip(emb, keep)])

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Standard deviation over the sequence dimension.

        Without a mask this is ``emb.std(dim=1)``; with a mask it is the square
        root of ``var_pooling``, which divides by the number of valid tokens.
        """
        if attention_mask is None:
            return emb.std(dim=1)
        var = self.var_pooling(emb, attention_mask, **kwargs)
        return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Variance over the sequence dimension.

        Without a mask this is ``emb.var(dim=1)``; with a mask the squared
        deviations of valid tokens are summed and divided by the valid count.
        """
        if attention_mask is None:
            return emb.var(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        valid_count = attention_mask.sum(dim=1)
        mean = ((emb * attention_mask).sum(dim=1) / valid_count).unsqueeze(1)  # (b, 1, d)
        squared_diff = (emb - mean) ** 2  # (b, L, d)
        var = (squared_diff * attention_mask).sum(dim=1) / valid_count  # (b, d)
        return var

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Return the embedding at position 0 of every sequence."""
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ):  # [mean, max]
        """Apply every configured pooling and concatenate along the last dim.

        Returns:
            Tensor of shape ``(b, n_pooling_types * d)``.

        Raises:
            KeyError: if a configured pooling name is not supported.
        """
        final_emb = [
            self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)  # (b, d)
            for pooling_type in self.pooling_types
        ]
        return torch.cat(final_emb, dim=-1)  # (b, n_pooling_types * d)
class ProteinDataset(TorchDataset):
    """Map-style dataset that serves raw protein sequence strings by index."""

    def __init__(self, sequences: list[str]):
        # Held by reference; the input list is not copied.
        self.sequences = sequences

    def __len__(self) -> int:
        """Number of stored sequences."""
        n_sequences = len(self.sequences)
        return n_sequences

    def __getitem__(self, idx: int) -> str:
        """Return the sequence stored at position ``idx``."""
        item = self.sequences[idx]
        return item
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]:
    """Return a collate function that tokenizes a batch of sequences.

    The returned callable invokes
    ``tokenizer(sequences, return_tensors="pt", padding='longest')``.

    It is a ``functools.partial`` rather than a nested closure so that it can
    be pickled: ``DataLoader`` workers started with the *spawn* method must
    pickle ``collate_fn``, and a locally defined function cannot be pickled.
    """
    # Function-scope import keeps this block independent of module import order.
    from functools import partial

    return partial(tokenizer, return_tensors="pt", padding='longest')
class EmbeddingMixin:
    """Mixin that adds batched sequence embedding with SQLite / ``.pth`` caching.

    The host class must implement ``_embed`` and, as used by the methods
    below, expose ``parameters()`` and ``config.hidden_size``.
    """

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return embeddings for a batch; must be overridden by the host class."""
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read the set of sequences already stored in the SQLite database."""
        from contextlib import closing

        # sqlite3's own context manager only commits or rolls back; ``closing``
        # is what actually releases the connection.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in cursor}

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        """Create the ``embeddings`` table if missing and add any missing columns.

        Tables that lack the ``shape`` / ``dtype`` columns get them added via
        ``ALTER TABLE``. Commits before returning.
        """
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL, "
            "shape TEXT, "
            "dtype TEXT"
            ")"
        )
        cursor.execute("PRAGMA table_info(embeddings)")
        rows = cursor.fetchall()
        column_names = [row[1] for row in rows]  # index 1 of table_info is the column name
        if "shape" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
        if "dtype" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
        conn.commit()

    def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]:
        """Load a ``{sequence: tensor}`` dictionary saved with ``torch.save``.

        Raises:
            AssertionError: if the file is missing or the payload is not a
                ``dict[str, torch.Tensor]``.
        """
        assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
        payload = torch.load(save_path, map_location="cpu", weights_only=True)
        assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
        for sequence, tensor in payload.items():
            assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
            assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
        return payload

    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]:
        """Load embeddings from the SQLite database.

        Args:
            db_path: path to an existing database file.
            sequences: if given, only these sequences are loaded; ``None``
                loads every row.

        Returns:
            ``{sequence: tensor}`` for every matching row.

        Raises:
            AssertionError: if the database is missing or a row has missing or
                inconsistent shape/dtype metadata.
        """
        from contextlib import closing

        assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
        select_sql = "SELECT sequence, embedding, shape, dtype FROM embeddings"
        # SQLite caps bound parameters per statement (999 on older builds), so
        # large requests are split into several IN (...) queries.
        max_params = 900
        loaded: dict[str, torch.Tensor] = {}
        with closing(sqlite3.connect(db_path)) as conn:
            self._ensure_embeddings_table(conn)
            cursor = conn.cursor()
            if sequences is None:
                cursor.execute(select_sql)
                rows = cursor.fetchall()
            else:
                rows = []
                for start in range(0, len(sequences), max_params):
                    chunk = sequences[start:start + max_params]
                    placeholders = ",".join(["?"] * len(chunk))
                    cursor.execute(f"{select_sql} WHERE sequence IN ({placeholders})", tuple(chunk))
                    rows.extend(cursor.fetchall())
        for sequence, embedding_bytes, shape_text, dtype_text in rows:
            assert shape_text is not None, "Missing shape metadata in embeddings table."
            assert dtype_text is not None, "Missing dtype metadata in embeddings table."
            shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
            assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
            expected_size = int(np.prod(shape_values))
            array = np.frombuffer(embedding_bytes, dtype=np.dtype(dtype_text))
            assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
            # frombuffer yields a read-only view of the blob; copy so the
            # tensor owns writable memory.
            loaded[sequence] = torch.from_numpy(array.copy().reshape(tuple(shape_values)))
        return loaded

    def embed_dataset(
        self,
        sequences: List[str],
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        **kwargs,
    ) -> Optional[dict[str, torch.Tensor]]:
        """
        Embed a dataset of protein sequences.

        Supports two modes:
            - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
            - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.

        Sequences are optionally truncated to `max_len`, de-duplicated and
        processed longest-first. Sequences already present in the target store
        (`sql_db_path` when `sql=True`, otherwise `save_path`) are skipped.

        Returns:
            `None` when `sql=True` (results are written to the database);
            otherwise a `{sequence: embedding}` dict, also written to
            `save_path` when `save=True`.
        """
        # NOTE: the list default of ``pooling_types`` is only read, never mutated.
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # processing order is deterministic (a set would not guarantee that).
        sequences = list(dict.fromkeys(seq[:max_len] if truncate else seq for seq in sequences))
        sequences = sorted(sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        pooler = Pooler(pooling_types) if not full_embeddings else None
        tokenizer_mode = tokenizer is not None
        if tokenizer_mode:
            collate_fn = build_collator(tokenizer)
            device = self.device
        else:
            collate_fn = None
            device = None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # 2-D output is passed through unchanged (nothing left to pool).
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def select_valid(emb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            # For per-residue output keep only rows where the mask is set.
            if full_embeddings:
                return emb[mask.bool()].reshape(-1, hidden_size)
            return emb

        def iter_batches(to_embed: List[str]):
            if tokenizer_mode:
                assert collate_fn is not None
                assert device is not None
                dataset = ProteinDataset(to_embed)
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    # shuffle=False, so batch i covers this slice of to_embed.
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    residue_embeddings = self._embed(input_ids, attention_mask)
                    yield seqs, residue_embeddings, attention_mask
            else:
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
                    assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
                    assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
                    residue_embeddings, attention_mask = batch_output
                    assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
                    yield seqs, residue_embeddings, attention_mask

        if sql:
            conn = sqlite3.connect(sql_db_path)
            try:
                self._ensure_embeddings_table(conn)
                c = conn.cursor()
                already_embedded = self._read_sequences_from_db(sql_db_path)
                to_embed = [seq for seq in sequences if seq not in already_embedded]
                print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
                print(f"Embedding {len(to_embed)} new sequences")
                if len(to_embed) > 0:
                    with torch.no_grad():
                        for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
                            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                            for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                                emb_np = select_valid(emb, mask).cpu().numpy()
                                emb_shape = ",".join([str(dim) for dim in emb_np.shape])
                                emb_dtype = str(emb_np.dtype)
                                c.execute(
                                    "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
                                    (seq, emb_np.tobytes(), emb_shape, emb_dtype),
                                )
                            # Periodic commit (tokenizer mode only) bounds the work lost on interruption.
                            if tokenizer_mode and (i + 1) % 100 == 0:
                                conn.commit()
                    conn.commit()
            finally:
                # Release the database handle even if embedding fails part-way.
                conn.close()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = self.load_embeddings_from_pth(save_path)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")
        if len(to_embed) > 0:
            with torch.no_grad():
                for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        embeddings_dict[seq] = select_valid(emb, mask).cpu()
        if save:
            torch.save(embeddings_dict, save_path)
        return embeddings_dict
"""
ESM++ model implementation.
ESM++ is a faithful implementation of ESMC that allows for batching and standard Hugging Face compatibility.
The ESM Python package is not required.
Modified from https://github.com/evolutionaryscale/esm
License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
"""
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from functools import cache, partial
from pathlib import Path
from typing import Optional, Tuple, Union, List
from einops import rearrange, repeat
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import ModelOutput
try:
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
except ImportError:
create_block_mask = None
flex_attention = None
BlockMask = None
from enum import Enum
### Kernels Flash Attention Detection
def _infer_kernels_flash_variant(kernel) -> str | None:
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
return "flash_attn2"
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
return "flash_attn3"
return None
def _try_get_kernels_flash():
try:
from kernels import get_kernel
except ImportError:
return None, None
flash_kernel = None
flash_kernel_variant = None
try:
flash_kernel = get_kernel("kernels-community/flash-attn3")
flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
except Exception:
try:
flash_kernel = get_kernel("kernels-community/flash-attn2")
flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
except Exception:
flash_kernel = None
flash_kernel_variant = None
return flash_kernel, flash_kernel_variant
# Resolved once at import time; both names are None when no kernel could be loaded.
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Run dense (unpadded-batch) attention through the loaded flash kernel.

    Dispatches on ``FLASH_KERNEL_VARIANT``. When the kernel returns a tuple,
    only its first element is returned.

    Raises:
        AssertionError: if no kernel is loaded or the variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    if FLASH_KERNEL_VARIANT == "flash_attn2":
        result = FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)
        return result[0]

    if FLASH_KERNEL_VARIANT != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        result = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
    except TypeError:
        # Keyword call rejected: retry with the positional form.
        result = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
    return result[0] if isinstance(result, tuple) else result
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Run variable-length (packed) attention through the loaded flash kernel.

    Dispatches on ``FLASH_KERNEL_VARIANT``. When the kernel returns a tuple,
    only its first element is returned.

    Raises:
        AssertionError: if no kernel is loaded or the variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    if FLASH_KERNEL_VARIANT == "flash_attn2":
        result = FLASH_KERNEL.varlen_fwd(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )
        return result[0]

    if FLASH_KERNEL_VARIANT != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        result = FLASH_KERNEL.flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            causal=causal,
        )
    except TypeError:
        # Keyword call rejected: retry with the positional form.
        result = FLASH_KERNEL.flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_in_batch_q,
            max_seqlen_in_batch_k,
            0.0,
            None,
            causal,
        )
    return result[0] if isinstance(result, tuple) else result
### Unpad / Pad helpers for varlen flash attention
class IndexFirstAxis(torch.autograd.Function):
    """Differentiable selection of rows along the first axis.

    Forward gathers ``input[indices]``; backward scatters the incoming
    gradient back into a zero tensor with the input's first-axis length.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        n_rows, trailing_shape = input.shape[0], input.shape[1:]
        ctx.first_axis_dim = n_rows
        row_width = trailing_shape.numel()
        # Flatten the trailing dims so a single gather along dim 0 picks whole rows.
        flat_input = input.reshape(n_rows, row_width)
        gather_index = indices.unsqueeze(-1).expand(-1, row_width)
        picked = torch.gather(flat_input, 0, gather_index)
        return picked.reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
        indices = ctx.saved_tensors[0]
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = grad_output.reshape(grad_output.shape[0], -1)
        row_width = flat_grad.shape[1]
        grad_input = torch.zeros(
            [ctx.first_axis_dim, row_width], device=flat_grad.device, dtype=flat_grad.dtype
        )
        # Rows that were not selected keep a zero gradient.
        grad_input.scatter_(0, indices.unsqueeze(-1).expand(-1, row_width), flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter of rows into a zero tensor along the first axis.

    Forward builds a ``(first_axis_dim, ...)`` zero tensor and writes
    ``values`` at ``indices``; backward reads the gradient back at the same
    indices.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        # new_zeros inherits dtype and device from ``values``.
        padded = values.new_zeros((first_axis_dim, *values.shape[1:]))
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
        indices = ctx.saved_tensors[0]
        grad_values = grad_output[indices]
        return grad_values, None, None
# Functional aliases for the autograd Functions above (used by the pad/unpad helpers).
index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed rows back into a zero-padded ``(batch, seqlen, ...)`` tensor.

    ``indices`` are positions in the flattened ``batch * seqlen`` axis; every
    position not listed stays zero.
    """
    flat_padded = index_put_first_axis(hidden_states, indices, batch * seqlen)
    # Split the flattened leading axis back into (batch, seqlen).
    return flat_padded.unflatten(0, (batch, seqlen))
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
    """Pack q/k/v by dropping positions where ``attention_mask_2d`` is zero.

    Returns:
        ``(q, k, v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen))``
        where q/k/v are ``(total_valid, heads, head_dim)``, ``indices`` are the
        kept positions in the flattened ``batch * seq`` axis and ``cu_seqlens``
        are int32 cumulative lengths with a leading 0. The same lengths are
        returned for queries and keys.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape

    tokens_per_row = attention_mask_2d.sum(dim=1).int()
    cu_seqlens = F.pad(tokens_per_row.cumsum(0, dtype=torch.int32), (1, 0))
    max_seqlen = int(tokens_per_row.max().item())
    indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    # All three tensors are flattened with the query's dimensions.
    flat_shape = (batch_size * seq_len, num_heads, head_dim)
    packed_q, packed_k, packed_v = (
        index_first_axis(layer.reshape(flat_shape), indices)
        for layer in (query_layer, key_layer, value_layer)
    )
    return packed_q, packed_k, packed_v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: torch.Tensor | None = None,
    causal: bool = False,
) -> torch.Tensor:
    """Top-level flash attention entry point.

    A non-causal call with a 2D padding mask is packed, run through the
    variable-length kernel and re-padded. Every other call (no mask, or
    ``causal=True``) goes to the dense kernel, and the mask is not used.

    Raises:
        AssertionError: if no flash kernel is loaded.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    use_varlen = attention_mask_2d is not None and not causal
    if not use_varlen:
        return _kernels_flash_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            causal=causal,
        )

    batch_size, q_len = query_states.shape[:2]
    packed_q, packed_k, packed_v, indices_q, cu_seqlens, max_seqlens = _unpad_input(
        query_states, key_states, value_states, attention_mask_2d
    )
    cu_seqlens_q, cu_seqlens_k = cu_seqlens
    max_seqlen_q, max_seqlen_k = max_seqlens
    packed_output = _kernels_flash_varlen_forward(
        query_states=packed_q,
        key_states=packed_k,
        value_states=packed_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q,
        max_seqlen_in_batch_k=max_seqlen_k,
    )
    return pad_input(packed_output, indices_q, batch_size, q_len)
### Attention Backend Enum & Resolution
class AttentionBackend(Enum):
    """Attention implementations selectable by their string value."""

    # Resolved to one of the concrete backends by resolve_attention_backend.
    AUTO = "auto"
    # Flash attention kernel loaded into FLASH_KERNEL.
    KERNELS_FLASH = "kernels_flash"
    # torch.nn.attention.flex_attention.
    FLEX = "flex"
    # Always-available fallback.
    SDPA = "sdpa"
# String values accepted for the ``attn_backend`` setting, in enum declaration order.
VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a backend name to the concrete ``AttentionBackend`` to use.

    ``"auto"`` prefers the flash kernel, then flex attention, then SDPA,
    depending on what is available. An explicit request is honoured only if
    that backend is available.

    Raises:
        AssertionError: if the name is unknown or the requested backend is
            not available in this environment.
    """
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    backend = AttentionBackend(requested_backend)

    if backend is AttentionBackend.AUTO:
        if FLASH_KERNEL is not None:
            return AttentionBackend.KERNELS_FLASH
        return AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA

    if backend is AttentionBackend.KERNELS_FLASH:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        return backend

    if backend is AttentionBackend.FLEX:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        return backend

    if backend is AttentionBackend.SDPA:
        return backend

    # Defensive: only reachable if a new enum member is added without a branch above.
    raise AssertionError(f"Unsupported attention backend: {requested_backend}")
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
    """Build the padding masks once so every transformer layer can reuse them.

    Returns:
        ``(attention_mask_2d, attention_mask_4d, flex_block_mask)``. All three
        are ``None`` when no mask is given. Otherwise the 2D boolean mask is
        always returned, plus the 4D pairwise mask for the SDPA/manual path or
        the block mask for the flex path.
    """
    if attention_mask is None:
        return None, None, None

    mask_2d = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # The flash path unpads from the 2D mask directly.
        return mask_2d, None, None

    if effective_backend != AttentionBackend.FLEX:
        # SDPA / manual: a pair (q, kv) is visible only when both positions are valid.
        query_valid = mask_2d.unsqueeze(1).unsqueeze(-1)  # (b, 1, L, 1)
        key_valid = mask_2d.unsqueeze(1).unsqueeze(1)  # (b, 1, 1, L)
        return mask_2d, query_valid & key_valid, None

    assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
    # NOTE(review): treating the first ``valid_lens`` positions as the valid ones
    # assumes the mask is a contiguous prefix (right padding) — confirm with callers.
    valid_lens = mask_2d.sum(dim=-1)

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        row_len = valid_lens[batch_idx]
        return (q_idx < row_len) & (kv_idx < row_len)

    block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
    return mask_2d, None, block_mask
class ESMplusplusConfig(PretrainedConfig):
    """Hyper-parameters of an ESM++ model.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        hidden_size: Width of the hidden representations.
        num_attention_heads: Attention heads per layer.
        num_hidden_layers: Number of transformer layers.
        num_labels: Number of output labels for classification heads.
        problem_type: Regression, single-label or multi-label classification.
        dropout: Dropout probability.
        initializer_range: Range used for weight initialisation.
        attn_backend: Attention backend name ("auto" by default).
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "ESMplusplus"

    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        problem_type: str | None = None,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        attn_backend: str = "auto",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Architecture.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attn_backend = attn_backend

        # Task head.
        self.num_labels = num_labels
        self.problem_type = problem_type

        # Regularisation / initialisation.
        self.dropout = dropout
        self.initializer_range = initializer_range

        # Always False, whatever was passed through kwargs.
        self.tie_word_embeddings = False
### Rotary Embeddings
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate the last dimension of ``x`` by pairing its components.

    Non-interleaved: the two halves ``(a, b)`` become ``(-b, a)``.
    Interleaved: every adjacent pair ``(x[2i], x[2i+1])`` becomes
    ``(-x[2i+1], x[2i])``.
    """
    if interleaved:
        even, odd = x[..., 0::2], x[..., 1::2]
        # Stack the pairs on a new last axis, then merge it back into the feature axis.
        return torch.stack((-odd, even), dim=-1).flatten(-2)

    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to ``x``.

    Args:
        x: ``(batch, seqlen, nheads, headdim)``.
        cos, sin: ``(>= seqlen, rotary_dim / 2)`` tables; only the first
            ``seqlen`` rows are used.
        interleaved: if True, rotate adjacent pairs ``(x[2i], x[2i+1])``;
            otherwise pair ``x[i]`` with ``x[i + rotary_dim / 2]``.
        _inplace: accepted for signature compatibility; unused.

    Returns:
        Tensor shaped like ``x``. Features beyond ``rotary_dim`` are passed
        through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]

    x_rot = x[..., :ro_dim]
    if interleaved:
        # Both members of an adjacent pair share one angle, so each table entry
        # must be duplicated in place: [c0, c0, c1, c1, ...].
        cos_full = cos.repeat_interleave(2, dim=-1)
        sin_full = sin.repeat_interleave(2, dim=-1)
        even, odd = x_rot[..., ::2], x_rot[..., 1::2]
        rotated = torch.stack((-odd, even), dim=-1).flatten(-2)
    else:
        # Pairs are (i, i + ro_dim / 2), so the table is tiled: [c0, c1, ..., c0, c1, ...].
        cos_full = torch.cat((cos, cos), dim=-1)
        sin_full = torch.cat((sin, sin), dim=-1)
        first_half, second_half = x_rot.chunk(2, dim=-1)
        rotated = torch.cat((-second_half, first_half), dim=-1)

    # (s, d) -> (s, 1, d) so the tables broadcast over batch and heads.
    cos_full = cos_full.unsqueeze(1)
    sin_full = sin_full.unsqueeze(1)
    return torch.cat(
        [
            x_rot * cos_full + rotated * sin_full,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings.

    Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"

    Args:
        dim: Dimension of the embedding
        base: Base for computing angular frequencies
        interleaved: Whether to use interleaved rotations
        scale_base: Base for scaling (stored and cached, but ``forward`` rejects it)
        scaling_factor: Factor for scaling positions
        pos_idx_in_fp32: Whether to compute position indices in fp32
        device: Computation device
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        interleaved: bool = False,
        scale_base: Optional[float] = None,
        scaling_factor: float = 1.0,
        pos_idx_in_fp32: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device
        # cos/sin tables are built lazily by _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        # Device inv_freq was last computed on; forward() recomputes on mismatch.
        self._inv_freq_compute_device: Optional[torch.device] = None
        self.reset_parameters()

    def reset_parameters(self):
        """Recompute the ``inv_freq`` / ``scale`` buffers and drop the cos/sin cache."""
        # Reuse the device of an existing inv_freq buffer; otherwise use the
        # device given at construction.
        if "inv_freq" in self._buffers and isinstance(self._buffers["inv_freq"], torch.Tensor):
            buffer_device = self._buffers["inv_freq"].device
        else:
            buffer_device = self.device
        inv_freq = self._compute_inv_freq(buffer_device)
        self._inv_freq_compute_device = inv_freq.device
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=buffer_device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Compute inverse frequency bands: ``1 / base ** (2i / dim)``."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Rebuild the cached cos/sin tables when they cannot serve this request.

        The cache is rebuilt when it is empty, shorter than ``seqlen``, on a
        different device or dtype, or was created as an inference tensor while
        the module is now in training mode.
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                # Build the angles in fp32 even if inv_freq has another dtype.
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                # Queries are multiplied by the scale, keys divided by it.
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys.

        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors

        Raises:
            AssertionError: if ``scale_base`` was set (scaled rotation is not
                implemented here).
        """
        assert self._inv_freq_compute_device is not None, "Rotary inv_freq compute device should be set after initialization."
        if self._inv_freq_compute_device != q.device:
            # inv_freq was computed on another device than the one now in use: recompute it.
            self.reset_parameters()
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is not None:
            # Explicit raise instead of ``assert False``: asserts are stripped under
            # ``python -O``, which would make this method silently return None.
            raise AssertionError("Rotary embeddings with scale_base set are not supported in forward().")
        return (
            apply_rotary_emb_torch(
                q,
                self._cos_cached,
                self._sin_cached,
                self.interleaved,
                True,  # inplace=True
            ),
            apply_rotary_emb_torch(
                k,
                self._cos_cached,
                self._sin_cached,
                self.interleaved,
                True,  # inplace=True
            ),
        )  # type: ignore
### Feedforward Network Components
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Return ``expansion_ratio * d_model`` rounded up to a multiple of 256."""
    multiple = 256
    padded = (expansion_ratio * d_model) + (multiple - 1)
    # Floor-divide then multiply back to snap onto the multiple.
    return int(padded // multiple * multiple)
class SwiGLU(nn.Module):
    """SwiGLU gate: split the last dim in two and return ``silu(a) * b``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = torch.chunk(x, 2, dim=-1)
        activated = F.silu(gate)
        return activated * value
def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """Build the feed-forward block: LayerNorm -> Linear -> SwiGLU -> Linear.

    The first projection is twice the hidden width because SwiGLU halves the
    last dimension. Both linear layers are bias-free.
    """
    hidden_dim = swiglu_correction_fn(expansion_ratio, d_model)
    # Modules are created in forward order, as in a direct nn.Sequential(...) call.
    layers = [
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden_dim * 2, bias=False),
        SwiGLU(),
        nn.Linear(hidden_dim, d_model, bias=False),
    ]
    return nn.Sequential(*layers)
### Attention
class MultiHeadAttention(nn.Module):
"""Multi-head attention with rotary embeddings and configurable backend.
Args:
d_model: Model dimension
n_heads: Number of attention heads
attn_backend: One of "auto", "kernels_flash", "flex", "sdpa"
"""
def __init__(
self,
d_model: int,
n_heads: int,
attn_backend: str = "auto",
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = self.d_model // self.n_heads
self.scale = 1.0 / math.sqrt(self.d_head)
self.attn_backend = resolve_attention_backend(attn_backend)
self.layernorm_qkv = nn.Sequential(
nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.q_ln = nn.LayerNorm(d_model, bias=False)
self.k_ln = nn.LayerNorm(d_model, bias=False)
self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
self.rotary = RotaryEmbedding(d_model // n_heads)
def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
q = q.unflatten(-1, (self.n_heads, self.d_head))
k = k.unflatten(-1, (self.n_heads, self.d_head))
q, k = self.rotary(q, k)
q = q.flatten(-2, -1)
k = k.flatten(-2, -1)
return q, k
def forward(
self,
x: torch.Tensor,
attention_mask_2d: torch.Tensor | None = None,
attention_mask_4d: torch.Tensor | None = None,
flex_block_mask: "BlockMask | None" = None,
output_attentions: bool = False,
output_s_max: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
qkv_BLD3 = self.layernorm_qkv(x)
query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
query_BLD, key_BLD = (
self.q_ln(query_BLD).to(query_BLD.dtype),
self.k_ln(key_BLD).to(query_BLD.dtype),
)
query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
attn_output, attn_weights, s_max = self._attn(
query_BHLD, key_BHLD, value_BHLD,
attention_mask_2d=attention_mask_2d,
attention_mask_4d=attention_mask_4d,
flex_block_mask=flex_block_mask,
output_attentions=output_attentions,
output_s_max=output_s_max,
)
output = self.out_proj(attn_output)
return output, attn_weights, s_max
def _attn(
self,
query_BHLD: torch.Tensor,
key_BHLD: torch.Tensor,
value_BHLD: torch.Tensor,
attention_mask_2d: torch.Tensor | None = None,
attention_mask_4d: torch.Tensor | None = None,
flex_block_mask: "BlockMask | None" = None,
output_attentions: bool = False,
output_s_max: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
if output_attentions:
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
if self.attn_backend == AttentionBackend.KERNELS_FLASH:
attn_output, attn_weights = self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
elif self.attn_backend == AttentionBackend.FLEX:
attn_output, attn_weights = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
elif self.attn_backend == AttentionBackend.SDPA:
attn_output, attn_weights = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
else:
raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")
s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None
return attn_output, attn_weights, s_max
@torch.no_grad()
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> list[torch.Tensor]:
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale
return [s_max_bound[h] for h in range(self.n_heads)]
def _manual_attn(
self,
query_BHLD: torch.Tensor,
key_BHLD: torch.Tensor,
value_BHLD: torch.Tensor,
attention_mask_4d: torch.Tensor | None = None,
output_s_max: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * self.scale
if attention_mask_4d is not None:
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
attn_weights = F.softmax(attn_weights, dim=-1)
context_BHLD = torch.matmul(attn_weights, value_BHLD)
attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None
return attn_output, attn_weights, s_max
def _kernels_flash_attn(
self,
query_BHLD: torch.Tensor,
key_BHLD: torch.Tensor,
value_BHLD: torch.Tensor,
attention_mask_2d: torch.Tensor | None = None,
) -> tuple[torch.Tensor, None]:
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
attn_output = kernels_flash_attention_func(
query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
attention_mask_2d=attention_mask_2d, causal=False,
)
return rearrange(attn_output, "b s h d -> b s (h d)"), None
def _flex_attn(
self,
query_BHLD: torch.Tensor,
key_BHLD: torch.Tensor,
value_BHLD: torch.Tensor,
flex_block_mask: "BlockMask | None" = None,
) -> tuple[torch.Tensor, None]:
assert flex_attention is not None, "Flex attention is not available in this environment."
assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
)
context_BHLD = flex_attention(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
def _sdpa_attn(
self,
query_BHLD: torch.Tensor,
key_BHLD: torch.Tensor,
value_BHLD: torch.Tensor,
attention_mask_4d: torch.Tensor | None = None,
) -> tuple[torch.Tensor, None]:
context_BHLD = F.scaled_dot_product_attention(
query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=self.scale,
)
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
### Regression Head
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Build a two-layer head: Linear -> GELU -> LayerNorm -> Linear.

    Args:
        d_model: Input dimension
        output_dim: Output dimension
        hidden_dim: Optional hidden dimension (defaults to d_model)
    """
    if hidden_dim is None:
        hidden_dim = d_model
    layers = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, output_dim),
    ]
    return nn.Sequential(*layers)
### Transformer Block
class UnifiedTransformerBlock(nn.Module):
    """One transformer layer: self-attention then a SwiGLU feedforward network.

    Each sub-layer output goes through dropout and is divided by
    ``residue_scaling_factor`` before being added back to the residual stream.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
        dropout: float = 0.0,
        attn_backend: str = "auto",
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            attn_backend=attn_backend,
        )
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask_2d: torch.Tensor | None = None,
        attention_mask_4d: torch.Tensor | None = None,
        flex_block_mask: "BlockMask | None" = None,
        output_attentions: bool = False,
        output_s_max: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]:
        """Apply the block and return ``(hidden, attn_weights, s_max)``.

        The last two entries are passed through unchanged from the attention module.
        """
        attn_output, attn_weights, s_max = self.attn(
            x,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
        )
        attn_update = self.dropout(attn_output) / self.scaling_factor
        hidden = x + attn_update
        ffn_update = self.dropout(self.ffn(hidden)) / self.scaling_factor
        hidden = hidden + ffn_update
        return hidden, attn_weights, s_max
### Model Outputs
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder.

    Every field defaults to ``None``; ``TransformerStack.forward`` only fills
    the optional ones when the matching ``output_*`` flag is set.
    """
    # Final hidden state, after the stack's closing LayerNorm.
    last_hidden_state: Optional[torch.Tensor] = None
    # Output of every block, followed by the normalised last hidden state.
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # One attention-weight tensor per block.
    attentions: Optional[Tuple[torch.Tensor]] = None
    # One entry per block; each entry is a list holding one tensor per attention head.
    s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
@dataclass
class ESMplusplusOutput(ModelOutput):
    """Output type for ESM++ models.

    Used by the base model and by every task head in this file. Fields a given
    model does not produce are left as ``None``.
    """
    # Loss computed by a task head when labels are supplied.
    loss: Optional[torch.Tensor] = None
    # Output of the task head (sequence head or classifier).
    logits: Optional[torch.Tensor] = None
    # The fields below are copied from the transformer stack's output.
    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None
    s_max: Optional[Tuple[list[torch.Tensor], ...]] = None
### Transformer Stack
class TransformerStack(nn.Module):
    """Stack of transformer blocks followed by a final bias-free LayerNorm.

    Args:
        d_model: Model dimension.
        n_heads: Number of attention heads per block.
        n_layers: Number of blocks. Also sets the residual scaling factor,
            ``sqrt(n_layers / 36)``, shared by every block.
        dropout: Dropout probability used inside each block.
        attn_backend: Attention backend name, resolved once for the stack and
            once inside each block.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
        dropout: float = 0.0,
        attn_backend: str = "auto",
    ):
        super().__init__()
        self.attention_backend = resolve_attention_backend(attn_backend)
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                    dropout=dropout,
                    attn_backend=attn_backend,
                )
                for _ in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)
        # Toggled externally; `_gradient_checkpointing_func` is expected to be
        # attached to this module by whoever enables it.
        self.gradient_checkpointing = False

    @property
    def attn_backend(self) -> AttentionBackend:
        """The resolved attention backend currently used by the stack."""
        return self.attention_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Resolve ``backend`` and push the result to every block's attention module."""
        resolved = resolve_attention_backend(backend)
        self.attention_backend = resolved
        for block in self.blocks:
            block.attn.attn_backend = resolved

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_s_max: Optional[bool] = False,
    ) -> TransformerOutput:
        """Run every block in order and normalise the result.

        Args:
            x: Input embeddings; dims 0 and 1 are read as batch size and
                sequence length when building the attention masks.
            attention_mask: Optional mask handed to ``get_attention_mask``.
            output_hidden_states: Collect each block's output plus the final
                normalised state.
            output_attentions: Collect each block's attention weights.
            output_s_max: Collect each block's per-head ``s_max`` list.

        Returns:
            A ``TransformerOutput``; collections that were not requested are ``None``.
        """
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None
        full_s_max = () if output_s_max else None

        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=x.shape[0],
            seq_len=x.shape[1],
            device=x.device,
            attention_mask=attention_mask,
        )

        use_checkpointing = self.gradient_checkpointing and self.training
        for block in self.blocks:
            if use_checkpointing:
                # Arguments are passed positionally on purpose:
                # torch.utils.checkpoint.checkpoint raises "Unexpected keyword
                # arguments" in reentrant mode when given keyword arguments for
                # the wrapped function. The order matches
                # UnifiedTransformerBlock.forward.
                x, attn_weights, s_max = self._gradient_checkpointing_func(
                    block.__call__,
                    x,
                    attention_mask_2d,
                    attention_mask_4d,
                    flex_block_mask,
                    output_attentions,
                    output_s_max,
                )
            else:
                x, attn_weights, s_max = block(
                    x=x,
                    attention_mask_2d=attention_mask_2d,
                    attention_mask_4d=attention_mask_4d,
                    flex_block_mask=flex_block_mask,
                    output_attentions=output_attentions,
                    output_s_max=output_s_max,
                )
            if attentions is not None:
                attentions += (attn_weights,)
            if hidden_states is not None:
                hidden_states += (x,)
            if full_s_max is not None:
                full_s_max += (s_max,)

        last_hidden_state = self.norm(x)
        if hidden_states is not None:
            # The normalised final state is appended after the raw block outputs.
            hidden_states += (last_hidden_state,)

        return TransformerOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=attentions,
            s_max=full_s_max,
        )
class PreTrainedESMplusplusModel(PreTrainedModel):
    """
    Shared base for ESM++ models: weight initialisation, attention-backend
    switching and checkpoint-loading hooks.
    """
    config_class = ESMplusplusConfig
    base_model_prefix = "esm++"
    supports_gradient_checkpointing = True
    all_tied_weights_keys = {}

    @classmethod
    def is_remote_code(cls) -> bool:
        # Prevent post-load reinitialization of tensors already loaded from checkpoints.
        return True

    def _init_weights(self, module):
        """Initialize the weights of ``module`` unless it was loaded from a checkpoint."""
        # HF from_pretrained marks loaded parameters with `_is_hf_initialized`.
        # Skip this module if any local parameter is already marked as loaded.
        already_loaded = any(
            parameter.__dict__.get("_is_hf_initialized", False)
            for parameter in module.parameters(recurse=False)
        )
        if already_loaded:
            return

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding row at zero after the random init.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    @property
    def attn_backend(self) -> str:
        """Backend name stored on the config."""
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Validate ``backend``, store it on the config and forward it to every stack."""
        assert backend in VALID_ATTENTION_BACKENDS, f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        self.config.attn_backend = backend
        stacks = (module for module in self.modules() if isinstance(module, TransformerStack))
        for stack in stacks:
            stack.attn_backend = backend

    def _reset_rotary_embeddings(self):
        """Refresh non-persistent rotary buffers after checkpoint loading."""
        for module in self.modules():
            if not isinstance(module, RotaryEmbedding):
                continue
            module.reset_parameters()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load via the parent ``from_pretrained`` and then reset the rotary buffers.

        When ``output_loading_info`` is truthy the parent returns a
        ``(model, loading_info)`` pair, which is unpacked and returned as a pair again.
        """
        wants_loading_info = bool(kwargs.get("output_loading_info", False))
        loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        if not wants_loading_info:
            loaded._reset_rotary_embeddings()
            return loaded
        model, loading_info = loaded
        model._reset_rotary_embeddings()
        return model, loading_info

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        """Load a pretrained ESM++ model, selected by '300' or '600' in ``model_name``."""
        if '300' in model_name:
            return ESMplusplus_300M()
        if '600' in model_name:
            return ESMplusplus_600M()
        raise ValueError(f"Invalid model name: {model_name}")
### ESM++ Models
class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model. Token embedding plus transformer stack, with no task head.
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return only the final hidden state for ``input_ids``."""
        token_embeddings = self.embed(input_ids)
        stack_output = self.transformer(
            x=token_embeddings,
            attention_mask=attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return stack_output.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_s_max: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> ESMplusplusOutput:
        """Encode token ids or precomputed embeddings; exactly one must be given.

        ``return_dict`` and extra keyword arguments are accepted but not used.
        """
        assert input_ids is not None or inputs_embeds is not None, "You have to specify either input_ids or inputs_embeds"
        assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"

        x = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        stack_output = self.transformer(
            x=x,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
        )
        return ESMplusplusOutput(
            last_hidden_state=stack_output.last_hidden_state,
            hidden_states=stack_output.hidden_states,
            attentions=stack_output.attentions,
            s_max=stack_output.s_max,
        )
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
    """
    ESM++ model for masked language modeling.
    Adds a vocabulary-sized head (``sequence_head``) on top of the transformer stack.
    """
    config_class = ESMplusplusConfig

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(
            d_model=config.hidden_size,
            n_heads=config.num_attention_heads,
            n_layers=config.num_hidden_layers,
            dropout=config.dropout,
            attn_backend=config.attn_backend,
        )
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = EsmSequenceTokenizer()
        self.init_weights()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # The last layer of the head is the projection onto the vocabulary.
        return self.sequence_head[-1]

    def set_output_embeddings(self, new_embeddings):
        self.sequence_head[-1] = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return only the final hidden state for ``input_ids``."""
        token_embeddings = self.embed(input_ids)
        stack_output = self.transformer(
            x=token_embeddings,
            attention_mask=attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return stack_output.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_s_max: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> ESMplusplusOutput:
        """Encode the input and predict per-position vocabulary logits.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given. A
        cross-entropy loss over the flattened logits is computed only when
        ``labels`` is supplied. ``return_dict`` and extra keyword arguments
        are accepted but not used.
        """
        x = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        output = self.transformer(
            x=x,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            output_s_max=output_s_max,
        )
        last_hidden_state = output.last_hidden_state
        logits = self.sequence_head(last_hidden_state)

        if labels is None:
            loss = None
        else:
            flat_logits = logits.view(-1, self.vocab_size)
            loss = self.ce_loss(flat_logits, labels.view(-1))

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=last_hidden_state,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            s_max=output.s_max,
        )
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a classification head applied to pooled
    hidden states.

    An optional ``pooling_types`` keyword argument (a non-empty list or tuple
    of pooling names) selects the poolings; otherwise ``['mean', 'var']`` is used.
    """

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Large intermediate projections help with sequence classification tasks (*4)
        # NOTE(review): the input width is fixed at hidden_size * 2, which presumably
        # matches two concatenated poolings -- confirm against Pooler before passing
        # a `pooling_types` of a different length.
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        # If kwargs carries a usable pooling_types, use it; otherwise use ['mean', 'var'].
        # The check is against plain list/tuple: isinstance() with a subscripted
        # generic such as List[str] raises TypeError.
        requested = kwargs.get('pooling_types')
        if (
            isinstance(requested, (list, tuple))
            and len(requested) > 0
            and all(isinstance(name, str) for name in requested)
        ):
            pooling_types = list(requested)
        else:
            pooling_types = ['mean', 'var']
        self.pooler = Pooler(pooling_types)
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return only the final hidden state for ``input_ids``."""
        x = self.embed(input_ids)
        return self.transformer(
            x=x,
            attention_mask=attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        ).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_s_max: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> ESMplusplusOutput:
        """Encode, pool and classify a batch of sequences.

        When ``labels`` is given and ``config.problem_type`` is unset, the
        problem type is inferred once and stored on the config: one label is
        regression, integer labels are single-label classification, anything
        else is multi-label classification. The loss then follows the problem
        type (MSE, cross-entropy or BCE-with-logits).
        ``return_dict`` and extra keyword arguments are accepted but not used.
        """
        # labels=None: the parent's masked-LM loss is not wanted here.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )
        last_hidden_state = output.last_hidden_state
        features = self.pooler(last_hidden_state, attention_mask)
        logits = self.classifier(features)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=last_hidden_state,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            s_max=output.s_max,
        )
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for token classification.
    Extends the base ESM++ model with a classification head applied at every position.
    """

    def __init__(self, config: ESMplusplusConfig, **kwargs):
        ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
        self.config = config
        self.num_labels = config.num_labels
        # Large intermediate projection (*4) for the classification head.
        self.classifier = RegressionHead(
            config.hidden_size,
            config.num_labels,
            config.hidden_size * 4,
        )
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return only the final hidden state for ``input_ids``."""
        token_embeddings = self.embed(input_ids)
        stack_output = self.transformer(
            token_embeddings,
            attention_mask,
            output_hidden_states=False,
            output_attentions=False,
        )
        return stack_output.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_s_max: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> ESMplusplusOutput:
        """Encode the input and classify every position.

        A cross-entropy loss over the flattened logits is computed only when
        ``labels`` is supplied. ``return_dict`` and extra keyword arguments
        are accepted but not used.
        """
        # labels=None: the parent's masked-LM loss is not wanted here.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_s_max=output_s_max,
        )
        last_hidden_state = output.last_hidden_state
        logits = self.classifier(last_hidden_state)

        if labels is None:
            loss = None
        else:
            flat_logits = logits.view(-1, self.num_labels)
            loss = self.loss_fct(flat_logits, labels.view(-1))

        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=last_hidden_state,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            s_max=output.s_max,
        )
### Loading from EvolutionaryScale
# Per-checkpoint metadata, keyed by the short name that
# `_resolve_esmc_checkpoint_key` returns.
#   repo_id             -- repository passed to `snapshot_download`
#   weights_relpath     -- weights file location, relative to the data root
#   hidden_size / num_attention_heads / num_hidden_layers
#                       -- architecture values the loader asserts the config matches
_ESMC_CHECKPOINT_SPECS = {
    "esmc-300": {
        "repo_id": "EvolutionaryScale/esmc-300m-2024-12",
        "weights_relpath": "data/weights/esmc_300m_2024_12_v0.pth",
        "hidden_size": 960,
        "num_attention_heads": 15,
        "num_hidden_layers": 30,
    },
    "esmc-600": {
        "repo_id": "EvolutionaryScale/esmc-600m-2024-12",
        "weights_relpath": "data/weights/esmc_600m_2024_12_v0.pth",
        "hidden_size": 1152,
        "num_attention_heads": 18,
        "num_hidden_layers": 36,
    },
}
def _resolve_esmc_checkpoint_key(model: str) -> str:
if "esmc-300" in model:
return "esmc-300"
if "esmc-600" in model:
return "esmc-600"
raise ValueError(f"{model=} is an invalid ESMC model name.")
@cache
def data_root(model: str) -> Path:
    """Return the local directory that holds the files for ``model``.

    When the ``INFRA_PROVIDER`` environment variable is set, an empty relative
    path is returned and nothing is downloaded. Otherwise the checkpoint
    repository is fetched with ``snapshot_download`` and its local path is
    returned. Results are memoised per ``model`` string.

    Raises:
        ValueError: If ``model`` is not a recognised ESMC name and
            ``INFRA_PROVIDER`` is unset.
    """
    # This is a plain module-level function: wrapping it in @staticmethod would
    # make it a non-callable staticmethod object on Python < 3.10.
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    key = _resolve_esmc_checkpoint_key(model)
    return Path(snapshot_download(repo_id=_ESMC_CHECKPOINT_SPECS[key]["repo_id"]))
def get_esmc_checkpoint_path(model: str) -> Path:
    """Return the full path of the weights file for ``model``."""
    key = _resolve_esmc_checkpoint_key(model)
    root = data_root(key)
    weights_relpath = _ESMC_CHECKPOINT_SPECS[key]["weights_relpath"]
    return root / weights_relpath
def _load_esmc_checkpoint_model(
    config: ESMplusplusConfig,
    model: str,
    device: torch.device | str = "cpu",
) -> ESMplusplusForMaskedLM:
    """Build an ``ESMplusplusForMaskedLM`` and load the matching ESMC weights.

    The config is first checked against the architecture recorded in
    ``_ESMC_CHECKPOINT_SPECS``; a mismatch fails an assertion before anything
    is loaded.

    Args:
        config: Model configuration; must match the checkpoint's architecture.
        model: Model name containing "esmc-300" or "esmc-600".
        device: Device the model is created on and the weights are mapped to.
    """
    key = _resolve_esmc_checkpoint_key(model)
    spec = _ESMC_CHECKPOINT_SPECS[key]

    assert config.hidden_size == spec["hidden_size"], (
        f"ESMC loader expected hidden_size={spec['hidden_size']} for {key}, "
        f"but got {config.hidden_size}."
    )
    assert config.num_attention_heads == spec["num_attention_heads"], (
        f"ESMC loader expected num_attention_heads={spec['num_attention_heads']} for {key}, "
        f"but got {config.num_attention_heads}."
    )
    assert config.num_hidden_layers == spec["num_hidden_layers"], (
        f"ESMC loader expected num_hidden_layers={spec['num_hidden_layers']} for {key}, "
        f"but got {config.num_hidden_layers}."
    )

    # Parameters created inside this context are allocated on `device`.
    with torch.device(device):
        model_obj = ESMplusplusForMaskedLM(config)

    checkpoint_path = get_esmc_checkpoint_path(key)
    # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model_obj.load_state_dict(state_dict)
    return model_obj
def ESMplusplus_300M(device: torch.device | str = "cpu"):
    """Load the "esmc-300" checkpoint into an ESM++ masked-LM model on ``device``."""
    architecture = {
        "hidden_size": 960,
        "num_attention_heads": 15,
        "num_hidden_layers": 30,
    }
    return _load_esmc_checkpoint_model(
        config=ESMplusplusConfig(**architecture),
        model="esmc-300",
        device=device,
    )
def ESMplusplus_600M(device: torch.device | str = "cpu"):
    """Load the "esmc-600" checkpoint into an ESM++ masked-LM model on ``device``."""
    architecture = {
        "hidden_size": 1152,
        "num_attention_heads": 18,
        "num_hidden_layers": 36,
    }
    return _load_esmc_checkpoint_model(
        config=ESMplusplusConfig(**architecture),
        model="esmc-600",
        device=device,
    )
### Tokenization
# Token vocabulary for EsmSequenceTokenizer. A token's position in this list is
# its id (the tokenizer builds its token-to-id map by enumerating the list),
# so the order must not change.
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Character-level tokenizer over ``SEQUENCE_VOCAB``.

    Single sequences are wrapped as ``<cls> ... <eos>``; pairs are joined as
    ``<cls> A <eos> B <eos>``. The chain-break token is registered as an
    additional special token.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        vocab = {token: index for index, token in enumerate(SEQUENCE_VOCAB)}
        # a character-level tokenizer is the same as BPE with no token merges
        backend = Tokenizer(BPE(vocab, merges=[], unk_token=unk_token))

        self.cb_token = chain_break_token
        backend.add_special_tokens(
            [cls_token, pad_token, mask_token, eos_token, chain_break_token]
        )

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        template_specials = [
            ("<cls>", backend.token_to_id("<cls>")),
            ("<eos>", backend.token_to_id("<eos>")),
        ]
        backend.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=template_specials,
        )

        super().__init__(
            tokenizer_object=backend,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=[chain_break_token],
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        """Every id from 0 up to ``vocab_size`` (exclusive)."""
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
if __name__ == "__main__":
    # Smoke test: build a tiny randomly-initialised masked-LM model, run one
    # batch through it and print the shape of every tensor in batch and output.
    import random
    import torch
    from torch import Tensor

    def print_tensor_shapes(prefix: str, obj):
        """Recursively print the shape of each tensor reachable from ``obj``."""
        if isinstance(obj, Tensor):
            print(f"{prefix}{obj.shape}")
            return
        if isinstance(obj, dict):
            for name, value in obj.items():
                print_tensor_shapes(f"{prefix}{name}.", value)
            return
        if isinstance(obj, (list, tuple)):
            for idx, value in enumerate(obj):
                print_tensor_shapes(f"{prefix}[{idx}].", value)
            return
        if hasattr(obj, "__dict__"):
            for name, value in vars(obj).items():
                # Private attributes are skipped.
                if not name.startswith("_"):
                    print_tensor_shapes(f"{prefix}{name}.", value)
            return
        print(f"{prefix}{type(obj)}")

    random.seed(0)
    torch.manual_seed(0)

    tokenizer = EsmSequenceTokenizer()
    # The two random.choice calls must stay in this order to keep the seeded
    # configuration reproducible.
    num_attention_heads = random.choice([2, 4])
    config = ESMplusplusConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=16 * num_attention_heads,
        num_attention_heads=num_attention_heads,
        num_hidden_layers=random.choice([1, 2]),
        num_labels=2,
        dropout=0.0,
    )

    batch = tokenizer(["ACDEFG", "MKTW"], return_tensors="pt", padding=True)
    batch["labels"] = batch["input_ids"].clone()

    model = ESMplusplusForMaskedLM(config=config).eval()
    with torch.no_grad():
        output = model(**batch, return_dict=True)

    for title, payload in (("Batch shape:", batch), ("Output shape:", output)):
        print(title)
        print_tensor_shapes("", payload)