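"""Activation statistics monitor for Hugging Face causal LM models.

Registers forward hooks on selected transformer submodules, accumulates
global element statistics (mean/std/min/max, zero/NaN/Inf counts) and
per-token RMS statistics over a set of prompts, optionally streams scalars
to TensorBoard, and can additionally report per-layer attention entropy.
Results are printed and can be saved as JSON.
"""
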
import argparse
import json
import math
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@dataclass
class RunningStat:
    count: int = 0
    sum: float = 0.0
    sumsq: float = 0.0
    min: Optional[float] = None
    max: Optional[float] = None
    zero_count: int = 0
    nan_count: int = 0
    inf_count: int = 0

    def update_from_tensor(self, t: torch.Tensor):
        with torch.no_grad():
            if t.numel() == 0:
                # min()/max() raise on empty tensors; nothing to accumulate.
                return
            nan_mask = torch.isnan(t)
            inf_mask = torch.isinf(t)
            self.nan_count += int(nan_mask.sum().item())
            self.inf_count += int(inf_mask.sum().item())
            t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)

            # NaN/Inf were zeroed above, so they also count toward zero_count
            # and contribute nothing to the moment sums.
            self.zero_count += int((t == 0).sum().item())

            tf = t.float()
            self.sum += float(tf.sum().item())
            self.sumsq += float((tf * tf).sum().item())
            self.count += t.numel()

            t_min = float(t.min().item())
            t_max = float(t.max().item())
            if self.min is None or t_min < self.min:
                self.min = t_min
            if self.max is None or t_max > self.max:
                self.max = t_max

    @property
    def mean(self) -> Optional[float]:
        if self.count == 0:
            return None
        return self.sum / self.count

    @property
    def var(self) -> Optional[float]:
        if self.count == 0:
            return None
        m = self.mean
        # Population variance via E[x^2] - E[x]^2, clamped against negative
        # floating-point round-off.
        return max(0.0, self.sumsq / self.count - (m * m))

    @property
    def std(self) -> Optional[float]:
        v = self.var
        if v is None:
            return None
        return math.sqrt(v)

    def to_dict(self) -> Dict[str, Any]:
        d = asdict(self)
        d["mean"] = self.mean
        d["std"] = self.std
        return d

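# Quick sanity check for RunningStat (illustrative, not part of the script):
#
#   rs = RunningStat()
#   rs.update_from_tensor(torch.tensor([0.0, 1.0, 2.0, float("nan")]))
#   # The NaN is counted, then zeroed before the moment sums, so:
#   # rs.count == 4, rs.nan_count == 1, rs.zero_count == 2, rs.mean == 0.75
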
@dataclass
class TokenRMSStat:
    count: int = 0
    sum: float = 0.0
    sumsq: float = 0.0

    def update_from_tensor(self, t: torch.Tensor):
        with torch.no_grad():
            if t.ndim == 1:
                feats = t.unsqueeze(0)
            else:
                # reshape rather than view: hook outputs are not guaranteed
                # to be contiguous.
                feats = t.reshape(-1, t.shape[-1])
            rms = feats.float().pow(2).mean(dim=-1).sqrt()
            rms = torch.nan_to_num(rms, nan=0.0, posinf=0.0, neginf=0.0)
            self.count += int(rms.numel())
            self.sum += float(rms.sum().item())
            self.sumsq += float((rms * rms).sum().item())

    @property
    def mean(self) -> Optional[float]:
        if self.count == 0:
            return None
        return self.sum / self.count

    @property
    def var(self) -> Optional[float]:
        if self.count == 0:
            return None
        m = self.mean
        return max(0.0, self.sumsq / self.count - (m * m))

    @property
    def std(self) -> Optional[float]:
        v = self.var
        if v is None:
            return None
        return math.sqrt(v)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "count": self.count,
            "mean": self.mean,
            "std": self.std,
        }

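# TokenRMSStat treats the last dimension as features: a hidden-state tensor of
# shape (batch, seq, hidden) yields batch*seq scalars, each sqrt(mean(h**2))
# over the hidden dimension. Illustrative check:
#
#   ts = TokenRMSStat()
#   ts.update_from_tensor(torch.ones(2, 3, 8))  # every token RMS == 1.0
#   # ts.count == 6, ts.mean == 1.0
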
class ActivationMonitor:
    def __init__(self, use_tensorboard: bool = False, tb_dir: Optional[str] = None):
        self.stats: Dict[str, RunningStat] = {}
        self.token_rms: Dict[str, TokenRMSStat] = {}
        self.use_tensorboard = use_tensorboard
        self.tb = None
        self._global_step = 0
        if self.use_tensorboard and tb_dir is not None:
            try:
                from torch.utils.tensorboard import SummaryWriter

                self.tb = SummaryWriter(log_dir=tb_dir)
            except Exception as e:
                print(f"TensorBoard not available: {e}")

    def _get_stat(self, name: str) -> RunningStat:
        if name not in self.stats:
            self.stats[name] = RunningStat()
        return self.stats[name]

    def _get_token_rms(self, name: str) -> TokenRMSStat:
        if name not in self.token_rms:
            self.token_rms[name] = TokenRMSStat()
        return self.token_rms[name]

    def hook(self, name: str):
        def _hook(module, inputs, output):
            with torch.no_grad():
                t = output
                if isinstance(t, tuple):
                    t = t[0]
                if not isinstance(t, torch.Tensor):
                    return
                self._get_stat(name).update_from_tensor(t)
                self._get_token_rms(name).update_from_tensor(t)

                if self.tb is not None and (self._global_step % 10 == 0):
                    rs = self.stats[name]
                    tr = self.token_rms[name]
                    if rs.count > 0:
                        self.tb.add_scalar(f"{name}/mean", rs.mean, self._global_step)
                        if rs.std is not None:
                            self.tb.add_scalar(f"{name}/std", rs.std, self._global_step)
                        self.tb.add_scalar(
                            f"{name}/zero_frac",
                            rs.zero_count / max(1, rs.count),
                            self._global_step,
                        )
                    if tr.count > 0 and tr.mean is not None:
                        self.tb.add_scalar(
                            f"{name}/token_rms_mean", tr.mean, self._global_step
                        )

        return _hook

    def step(self):
        self._global_step += 1

    def close(self):
        if self.tb is not None:
            self.tb.flush()
            self.tb.close()

    def to_dict(self) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for k in sorted(self.stats.keys()):
            out[k] = {
                "global": self.stats[k].to_dict(),
                "token_rms": self.token_rms[k].to_dict(),
            }
        return out

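# The monitor is model-agnostic; a minimal standalone sketch of the hook
# mechanism on a toy module (names here are hypothetical):
#
#   mon = ActivationMonitor()
#   layer = torch.nn.Linear(16, 16)
#   handle = layer.register_forward_hook(mon.hook("toy.linear"))
#   _ = layer(torch.randn(4, 16))
#   handle.remove()
#   print(mon.to_dict()["toy.linear"]["global"]["mean"])
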
def find_modules_to_hook(
    model: torch.nn.Module, patterns: List[str]
) -> List[str]:
    names: List[str] = []
    for name, _ in model.named_modules():
        lname = name.lower()
        # Only consider submodules inside the transformer blocks; this assumes
        # LLaMA-style naming where blocks live under "model.layers.".
        if not lname.startswith("model.layers."):
            continue
        for p in patterns:
            if p in lname:
                names.append(name)
                break
    return sorted(set(names))

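# On LLaMA-style checkpoints the matched names look like
# "model.layers.0.self_attn.q_proj" or "model.layers.7.mlp.down_proj".
# Architectures whose blocks do not live under "model.layers." will match
# nothing; adjust the prefix above for those.
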
def compute_attention_entropy(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    max_length: int,
    input_device: torch.device,
) -> Dict[int, float]:
    prev = getattr(model.config, "output_attentions", False)
    model.config.output_attentions = True

    with torch.inference_mode():
        enc = tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        for k in enc:
            enc[k] = enc[k].to(input_device)
        out = model(**enc, output_attentions=True, use_cache=False)
        atts = out.attentions
        entropies: Dict[int, float] = {}
        if atts is not None:
            # One entropy value per layer, averaged over batch, heads, and
            # query positions (padded positions are included in the mean).
            for i, att in enumerate(atts):
                probs = att.float().clamp_min(1e-12)
                ent = -(probs * probs.log()).sum(dim=-1)
                entropies[i] = float(ent.mean().item())

    model.config.output_attentions = prev
    return entropies

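# For reference: the per-row quantity above is the Shannon entropy
# H = -sum_k p_k * log(p_k) of each attention distribution, ranging from 0
# (all mass on one key) to log(key_len) (uniform attention). Comparing a
# layer's mean against log(max_length) gives a rough sense of how diffuse
# its attention is.
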
def load_prompts(
    prompts: Optional[str], prompts_file: Optional[str]
) -> List[str]:
    lines: List[str] = []
    if prompts_file:
        with open(prompts_file, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip("\n")
                if s:
                    lines.append(s)
    if prompts:
        for s in prompts.split("\n"):
            s = s.strip()
            if s:
                lines.append(s)
    if not lines:
        lines = [
            "Hello! Briefly introduce yourself.",
            "Explain the concept of attention in transformers.",
            "List three use cases for large language models.",
        ]
    return lines

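# Prompt sources, in order: --prompts_file (one prompt per line, blank lines
# skipped), then --prompts (newline-separated); if neither yields anything,
# the three built-in defaults above are used.
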
def main():
    ap = argparse.ArgumentParser(
        description="Activation statistics monitor for HF CausalLM models."
    )
    ap.add_argument("--model", type=str, required=True, help="Model path or HF ID.")
    ap.add_argument("--prompts", type=str, help="Newline-separated prompts.")
    ap.add_argument("--prompts_file", type=str, help="File with one prompt per line.")
    ap.add_argument("--max_length", type=int, default=256)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    ap.add_argument("--device_map", type=str, default="auto")
    ap.add_argument(
        "--patterns",
        type=str,
        default=(
            "q_proj,k_proj,v_proj,o_proj,mlp.up_proj,mlp.gate_proj,"
            "mlp.down_proj,layernorm,norm"
        ),
        help="Comma-separated substrings matched against module names.",
    )
    ap.add_argument("--save_json", type=str, help="Path to write the stats JSON.")
    ap.add_argument("--tensorboard_dir", type=str, help="Enable TensorBoard logging.")
    ap.add_argument("--attention_entropy", action="store_true")
    args = ap.parse_args()

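    # Example invocation (script name and model id are illustrative):
    #   python activation_monitor.py --model meta-llama/Llama-2-7b-hf \
    #       --prompts_file prompts.txt --attention_entropy --save_json stats.json
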
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading tokenizer/model: {args.model}")
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tok.pad_token is None:
        # Many causal LM tokenizers ship without a pad token; batched
        # tokenization with padding=True fails unless one is set.
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map=args.device_map,
    )
    model.eval()

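    # Note: attention backends such as SDPA/flash do not materialize attention
    # weights, so --attention_entropy may come back empty; loading with
    # attn_implementation="eager" makes the weights available.
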
    embed_device = model.get_input_embeddings().weight.device
    print(f"Sending inputs to: {embed_device}")

    patterns = [p.strip().lower() for p in args.patterns.split(",") if p.strip()]
    to_hook = find_modules_to_hook(model, patterns)

    mon = ActivationMonitor(
        use_tensorboard=args.tensorboard_dir is not None,
        tb_dir=args.tensorboard_dir,
    )
    handles = []
    for name, module in model.named_modules():
        if name in to_hook:
            handles.append(module.register_forward_hook(mon.hook(name)))
    print(f"Registered hooks on {len(handles)} modules.")

    prompts = load_prompts(args.prompts, args.prompts_file)

    with torch.inference_mode():
        i = 0
        while i < len(prompts):
            batch_prompts = prompts[i : i + args.batch_size]
            i += args.batch_size
            enc = tok(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=args.max_length,
            )
            for k in enc:
                enc[k] = enc[k].to(embed_device)
            _ = model(**enc, use_cache=False)
            mon.step()

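    # mon.step() advances the global step once per batch; the hooks only emit
    # TensorBoard scalars on every 10th step, so short prompt lists may log
    # only the step-0 point.
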
    attn_entropy: Dict[int, float] = {}
    if args.attention_entropy:
        subset = prompts[: min(len(prompts), args.batch_size)]
        attn_entropy = compute_attention_entropy(
            model, tok, subset, args.max_length, embed_device
        )

    for h in handles:
        h.remove()
    mon.close()

    stats = mon.to_dict()
    if args.attention_entropy:
        stats["_attention_entropy"] = attn_entropy

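    # stats maps module name -> {"global": {...}, "token_rms": {...}}; the
    # per-layer entropies live under "_attention_entropy" (integer layer
    # indices become string keys once serialized to JSON).
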
print("\nActivation summary (top 10 by token_rms mean):") |
|
|
ranked = sorted( |
|
|
[ |
|
|
(name, d["token_rms"]["mean"] or 0.0) |
|
|
for name, d in stats.items() |
|
|
if name != "_attention_entropy" |
|
|
], |
|
|
key=lambda x: x[1], |
|
|
reverse=True, |
|
|
)[:10] |
|
|
for name, rms_mean in ranked: |
|
|
g = stats[name]["global"] |
|
|
zero_frac = g.get("zero_count", 0) / max(1, g.get("count", 1)) |
|
|
print( |
|
|
f"- {name}: token_rms_mean={rms_mean:.4f}, " |
|
|
f"mean={g.get('mean'):.4f} std={g.get('std'):.4f} " |
|
|
f"min={g.get('min'):.4f} max={g.get('max'):.4f} " |
|
|
f"zero_frac={zero_frac:.4f}" |
|
|
) |
|
|
|
|
|
    if args.save_json:
        out_path = Path(args.save_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump(stats, f, indent=2)
        print(f"\nSaved stats JSON to: {out_path}")


if __name__ == "__main__":
    main()