repro code
revision aee6a1a (verified)
#!/usr/bin/env python3
# fmt: off
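"""Activation statistics monitor for HF causal LMs.

Registers forward hooks on selected submodules, runs a batch of prompts,
and reports elementwise stats (mean/std/min/max, zero/NaN/Inf counts) plus
per-token RMS for each hooked module. Optionally logs running values to
TensorBoard and computes per-layer attention entropy.
"""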
import argparse
import json
import math
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@dataclass
class RunningStat:
    """Streaming elementwise statistics over every activation value seen."""
    count: int = 0
    sum: float = 0.0
    sumsq: float = 0.0
    min: Optional[float] = None
    max: Optional[float] = None
    zero_count: int = 0
    nan_count: int = 0
    inf_count: int = 0

    def update_from_tensor(self, t: torch.Tensor):
        with torch.no_grad():
            nan_mask = torch.isnan(t)
            inf_mask = torch.isinf(t)
            self.nan_count += int(nan_mask.sum().item())
            self.inf_count += int(inf_mask.sum().item())
            # NaN/Inf values are zeroed before the remaining stats, so they
            # also show up in zero_count.
            t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)
            self.zero_count += int((t == 0).sum().item())
            tf = t.float()
            self.sum += float(tf.sum().item())
            self.sumsq += float((tf * tf).sum().item())
            self.count += t.numel()
            t_min = float(t.min().item())
            t_max = float(t.max().item())
            if self.min is None or t_min < self.min:
                self.min = t_min
            if self.max is None or t_max > self.max:
                self.max = t_max

    @property
    def mean(self) -> Optional[float]:
        if self.count == 0:
            return None
        return self.sum / self.count

    @property
    def var(self) -> Optional[float]:
        if self.count == 0:
            return None
        m = self.mean
        # E[x^2] - E[x]^2, clamped to absorb tiny negative values from
        # floating-point error.
        return max(0.0, self.sumsq / self.count - (m * m))

    @property
    def std(self) -> Optional[float]:
        v = self.var
        if v is None:
            return None
        return math.sqrt(v)

    def to_dict(self) -> Dict[str, Any]:
        d = asdict(self)
        d["mean"] = self.mean
        d["std"] = self.std
        return d
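# Per-token RMS over the hidden dim (the quantity RMSNorm divides by) tends
# to surface outlier tokens and exploding layers more directly than raw
# elementwise stats.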
@dataclass
class TokenRMSStat:
    """Streaming stats over the per-token RMS of the last (feature) dim."""
    count: int = 0
    sum: float = 0.0
    sumsq: float = 0.0

    def update_from_tensor(self, t: torch.Tensor):
        with torch.no_grad():
            if t.ndim == 1:
                feats = t.unsqueeze(0)
            else:
                # reshape instead of view: hook outputs are not guaranteed
                # to be contiguous.
                feats = t.reshape(-1, t.shape[-1])
            rms = feats.float().pow(2).mean(dim=-1).sqrt()
            rms = torch.nan_to_num(rms, nan=0.0, posinf=0.0, neginf=0.0)
            self.count += int(rms.numel())
            self.sum += float(rms.sum().item())
            self.sumsq += float((rms * rms).sum().item())

    @property
    def mean(self) -> Optional[float]:
        if self.count == 0:
            return None
        return self.sum / self.count

    @property
    def var(self) -> Optional[float]:
        if self.count == 0:
            return None
        m = self.mean
        return max(0.0, self.sumsq / self.count - (m * m))

    @property
    def std(self) -> Optional[float]:
        v = self.var
        if v is None:
            return None
        return math.sqrt(v)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "count": self.count,
            "mean": self.mean,
            "std": self.std,
        }
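# Aggregates both stat types via forward hooks; every 10th monitor step the
# current running values are mirrored to TensorBoard when it is enabled.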
class ActivationMonitor:
    def __init__(self, use_tensorboard: bool = False, tb_dir: Optional[str] = None):
        self.stats: Dict[str, RunningStat] = {}
        self.token_rms: Dict[str, TokenRMSStat] = {}
        self.use_tensorboard = use_tensorboard
        self.tb = None
        self._global_step = 0
        if self.use_tensorboard and tb_dir is not None:
            try:
                from torch.utils.tensorboard import SummaryWriter
                self.tb = SummaryWriter(log_dir=tb_dir)
            except Exception as e:
                print(f"TensorBoard not available: {e}")

    def _get_stat(self, name: str) -> RunningStat:
        if name not in self.stats:
            self.stats[name] = RunningStat()
        return self.stats[name]

    def _get_token_rms(self, name: str) -> TokenRMSStat:
        if name not in self.token_rms:
            self.token_rms[name] = TokenRMSStat()
        return self.token_rms[name]

    def hook(self, name: str):
        def _hook(module, inputs, output):
            with torch.no_grad():
                t = output
                if isinstance(t, tuple):
                    # Many HF modules return (hidden_states, ...); keep the first.
                    t = t[0]
                if not isinstance(t, torch.Tensor):
                    return
                self._get_stat(name).update_from_tensor(t)
                self._get_token_rms(name).update_from_tensor(t)
                if self.tb is not None and (self._global_step % 10 == 0):
                    # These are running stats over everything seen so far,
                    # not per-step values.
                    rs = self.stats[name]
                    tr = self.token_rms[name]
                    if rs.count > 0:
                        self.tb.add_scalar(f"{name}/mean", rs.mean, self._global_step)
                        if rs.std is not None:
                            self.tb.add_scalar(f"{name}/std", rs.std, self._global_step)
                        self.tb.add_scalar(
                            f"{name}/zero_frac",
                            rs.zero_count / max(1, rs.count),
                            self._global_step,
                        )
                    if tr.count > 0 and tr.mean is not None:
                        self.tb.add_scalar(
                            f"{name}/token_rms_mean", tr.mean, self._global_step
                        )
        return _hook

    def step(self):
        self._global_step += 1

    def close(self):
        if self.tb is not None:
            self.tb.flush()
            self.tb.close()

    def to_dict(self) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for k in sorted(self.stats.keys()):
            out[k] = {
                "global": self.stats[k].to_dict(),
                "token_rms": self.token_rms[k].to_dict(),
            }
        return out
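# Pattern matching below is a plain substring test on lowercased module names,
# e.g. pattern "q_proj" matches "model.layers.0.self_attn.q_proj".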
def find_modules_to_hook(
    model: torch.nn.Module, patterns: List[str]
) -> List[str]:
    # Only hook submodules inside transformer blocks; the "model.layers."
    # prefix assumes Llama-style module naming.
    names: List[str] = []
    for name, _ in model.named_modules():
        lname = name.lower()
        if not lname.startswith("model.layers."):
            continue
        for p in patterns:
            if p in lname:
                names.append(name)
                break
    return sorted(set(names))
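# Per-layer attention entropy H = -sum_k p_k * log(p_k) over the key axis,
# averaged across batch, heads, and query positions. A uniform row over n
# keys gives the maximum H = log(n); values near 0 mean sharply peaked heads.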
def compute_attention_entropy(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    max_length: int,
    input_device: torch.device,
) -> Dict[int, float]:
    prev = getattr(model.config, "output_attentions", False)
    model.config.output_attentions = True
    entropies: Dict[int, float] = {}
    with torch.inference_mode():
        enc = tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        for k in enc:
            enc[k] = enc[k].to(input_device)
        out = model(**enc, output_attentions=True, use_cache=False)
        atts = out.attentions
        if atts is None:
            # SDPA/flash attention kernels may not materialize attention
            # weights; loading with attn_implementation="eager" avoids this.
            print("Model returned no attentions; skipping entropy.")
        else:
            for i, att in enumerate(atts):
                # att: (batch, heads, query, key); entropy over the key dim.
                probs = att.float().clamp_min(1e-12)
                ent = -(probs * probs.log()).sum(dim=-1)
                entropies[i] = float(ent.mean().item())
    model.config.output_attentions = prev
    return entropies
def load_prompts(
    prompts: Optional[str], prompts_file: Optional[str]
) -> List[str]:
    lines: List[str] = []
    if prompts_file:
        with open(prompts_file, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip("\n")
                if s:
                    lines.append(s)
    if prompts:
        for s in prompts.split("\n"):
            s = s.strip()
            if s:
                lines.append(s)
    if not lines:
        lines = [
            "Hello! Briefly introduce yourself.",
            "Explain the concept of attention in transformers.",
            "List three use cases for large language models.",
        ]
    return lines
def main():
    ap = argparse.ArgumentParser(
        description="Activation statistics monitor for HF CausalLM models."
    )
    ap.add_argument("--model", type=str, required=True, help="Model path or HF ID.")
    ap.add_argument("--prompts", type=str)
    ap.add_argument("--prompts_file", type=str)
    ap.add_argument("--max_length", type=int, default=256)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    ap.add_argument("--device_map", type=str, default="auto")
    ap.add_argument(
        "--patterns",
        type=str,
        default=(
            "q_proj,k_proj,v_proj,o_proj,mlp.up_proj,mlp.gate_proj,"
            "mlp.down_proj,layernorm,norm"
        ),
    )
    ap.add_argument("--save_json", type=str)
    ap.add_argument("--tensorboard_dir", type=str)
    ap.add_argument("--attention_entropy", action="store_true")
    args = ap.parse_args()

    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]
    print(f"Loading tokenizer/model: {args.model}")
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tok.pad_token is None:
        # Many causal LM tokenizers ship without a pad token; padding=True
        # below would otherwise raise.
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map=args.device_map,
    )
    model.eval()
    # With device_map="auto" the model may be sharded; send inputs wherever
    # the embedding layer lives.
    embed_device = model.get_input_embeddings().weight.device
    print(f"Sending inputs to: {embed_device}")

    patterns = [p.strip().lower() for p in args.patterns.split(",") if p.strip()]
    to_hook = find_modules_to_hook(model, patterns)
    mon = ActivationMonitor(
        use_tensorboard=args.tensorboard_dir is not None,
        tb_dir=args.tensorboard_dir,
    )
    handles = []
    for name, module in model.named_modules():
        if name in to_hook:
            handles.append(module.register_forward_hook(mon.hook(name)))
    print(f"Registered hooks on {len(handles)} modules.")

    prompts = load_prompts(args.prompts, args.prompts_file)
    with torch.inference_mode():
        for i in range(0, len(prompts), args.batch_size):
            batch_prompts = prompts[i : i + args.batch_size]
            enc = tok(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=args.max_length,
            )
            for k in enc:
                enc[k] = enc[k].to(embed_device)
            _ = model(**enc, use_cache=False)
            mon.step()

    attn_entropy: Dict[int, float] = {}
    if args.attention_entropy:
        subset = prompts[: min(len(prompts), args.batch_size)]
        attn_entropy = compute_attention_entropy(
            model, tok, subset, args.max_length, embed_device
        )

    for h in handles:
        h.remove()
    mon.close()

    stats = mon.to_dict()
    if args.attention_entropy:
        stats["_attention_entropy"] = attn_entropy

    print("\nActivation summary (top 10 by token_rms mean):")
    ranked = sorted(
        [
            (name, d["token_rms"]["mean"] or 0.0)
            for name, d in stats.items()
            if name != "_attention_entropy"
        ],
        key=lambda x: x[1],
        reverse=True,
    )[:10]
    for name, rms_mean in ranked:
        g = stats[name]["global"]
        zero_frac = g.get("zero_count", 0) / max(1, g.get("count", 1))
        print(
            f"- {name}: token_rms_mean={rms_mean:.4f}, "
            f"mean={g.get('mean'):.4f} std={g.get('std'):.4f} "
            f"min={g.get('min'):.4f} max={g.get('max'):.4f} "
            f"zero_frac={zero_frac:.4f}"
        )
    if args.save_json:
        out_path = Path(args.save_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump(stats, f, indent=2)
        print(f"\nSaved stats JSON to: {out_path}")


if __name__ == "__main__":
    main()
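Example run (script name and model ID are illustrative):

python act_monitor.py --model meta-llama/Llama-2-7b-hf --dtype bfloat16 --attention_entropy --save_json stats.json

If --attention_entropy reports no attentions, reload the model with attn_implementation="eager"; SDPA kernels don't materialize attention weights.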