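"""Layer surgery on safetensors checkpoints: replace and/or rescale layers.

Copies a "composite" model shard by shard, optionally substituting selected
transformer layers with layers from a donor ("base") model and/or multiplying
selected projection weights by per-layer scale factors, then rewrites the
safetensors index.

Illustrative invocation (script and model names are placeholders):

    python layer_surgery.py \
        --composite ./composite-model \
        --base org/base-model \
        --out ./surgered-model \
        --replace_layers "48-55,60" \
        --scale_json scales.json
"""
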
import argparse
import glob
import json
import math
import os
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from safetensors.torch import save_file


def read_json(p: str) -> Dict:
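    """Load a JSON file into a dict."""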
    with open(p, "r") as f:
        return json.load(f)


def write_json(p: Path, data: Dict):
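    """Write a dict to a JSON file with 2-space indentation."""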
    with open(p, "w") as f:
        json.dump(data, f, indent=2)


def ensure_local(model_or_path: str) -> str:
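    """Return a local directory for the model, downloading from the HF Hub if needed."""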
    if os.path.isdir(model_or_path):
        return model_or_path
    print(f"Downloading {model_or_path} ...")
    return snapshot_download(
        model_or_path, cache_dir="./model_cache", resume_download=True
    )


def index_dir(model_dir: str) -> Tuple[Dict[str, str], List[str]]:
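    """Map tensor names to shard filenames.

    Prefers model.safetensors.index.json when present; otherwise scans the
    *.safetensors files directly. Returns (weight_map, shard_paths), where
    weight_map values are shard basenames and shard_paths are full paths.
    """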
    idx_path = os.path.join(model_dir, "model.safetensors.index.json")
    weight_map: Dict[str, str] = {}
    files: List[str] = []
    if os.path.exists(idx_path):
        idx = read_json(idx_path)
        weight_map = idx.get("weight_map", {})
        files = sorted({os.path.join(model_dir, f) for f in weight_map.values()})
        return weight_map, files

    st_files = glob.glob(os.path.join(model_dir, "*.safetensors"))
    if not st_files:
        raise FileNotFoundError(f"No .safetensors found in {model_dir}")
    for fpath in st_files:
        with safe_open(fpath, framework="pt") as f:
            for k in f.keys():
                weight_map[k] = os.path.basename(fpath)
    return weight_map, sorted(st_files)


def parse_layers(spec: str) -> List[int]:
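    """Parse a layer spec like "61" or "48-55,60,62" into a sorted list of ints."""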
    out: List[int] = []
    for chunk in spec.split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        if "-" in chunk:
            a, b = chunk.split("-")
            out.extend(range(int(a), int(b) + 1))
        else:
            out.append(int(chunk))
    return sorted(set(out))


def layer_prefix(li: int) -> str:
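    """Tensor-name prefix for transformer layer li."""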
    return f"model.layers.{li}."


def map_layer(dst_idx: int, dst_total: int, src_total: int, mode: str) -> int:
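    """Map a destination layer index onto a donor layer index.

    "wrap" wraps modulo the donor depth; "ratio" (the default) rescales the
    index proportionally, e.g. dst layer 61 of 64 onto a 48-layer donor gives
    floor(61 * 48 / 64) = 45.
    """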
    if src_total <= 0:
        raise ValueError("src_total must be > 0")
    if mode == "wrap":
        return dst_idx % src_total
    x = int(math.floor(dst_idx * src_total / max(1, dst_total)))
    return max(0, min(src_total - 1, x))


def build_explicit_map(pairs: Optional[str]) -> Dict[int, int]:
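    """Parse explicit "dst:src" pairs (e.g. "61:34,55:30") into {dst: src}."""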
    m: Dict[int, int] = {}
    if not pairs:
        return m
    for token in pairs.split(","):
        token = token.strip()
        if not token:
            continue
        a, b = token.split(":")
        m[int(a)] = int(b)
    return m


SCALE_KEYS = {
    "attn_q": ".self_attn.q_proj.weight",
    "attn_k": ".self_attn.k_proj.weight",
    "attn_v": ".self_attn.v_proj.weight",
    "attn_o": ".self_attn.o_proj.weight",
    "mlp_up": ".mlp.up_proj.weight",
    "mlp_gate": ".mlp.gate_proj.weight",
    "mlp_down": ".mlp.down_proj.weight",
}
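
# The SCALE_KEYS names index the per-layer multipliers in --scale_json.
# A minimal illustrative scales.json (file name and values are only an example):
#
#   {"61": {"mlp_down": 0.9, "attn_o": 0.95}}
#
# This would multiply layer 61's down_proj weight by 0.9 and o_proj by 0.95.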


def load_scales(scale_json: Optional[str]) -> Dict[int, Dict[str, float]]:
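    """Load per-layer scale factors from JSON, validating keys against SCALE_KEYS."""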
    if not scale_json:
        return {}
    data = read_json(scale_json)
    out: Dict[int, Dict[str, float]] = {}
    for k, v in data.items():
        li = int(k)
        out[li] = {}
        for mk, sf in v.items():
            if mk not in SCALE_KEYS:
                raise ValueError(f"Unknown scale key '{mk}'. Valid: {list(SCALE_KEYS)}")
            out[li][mk] = float(sf)
    return out


def tensor_layer_idx(tensor_name: str) -> Optional[int]:
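    """Extract the layer index from a "model.layers.<i>..." tensor name, else None."""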
    parts = tensor_name.split(".")
    if len(parts) > 3 and parts[0] == "model" and parts[1] == "layers":
        try:
            return int(parts[2])
        except ValueError:
            return None
    return None


def apply_scales_if_needed(
    tname: str, tensor: torch.Tensor, li: int, scales: Dict[int, Dict[str, float]]
) -> torch.Tensor:
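    """Multiply the tensor by its configured scale factor, if layer li has one for tname."""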
    if li not in scales:
        return tensor
    spec = scales[li]
    for key, suffix in SCALE_KEYS.items():
        if key in spec and tname.endswith(suffix):
            s = spec[key]
            # new_tensor keeps the multiplier on the tensor's own dtype/device.
            return (tensor * tensor.new_tensor(s)).contiguous()
    return tensor


def main():
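    """CLI entry point: copy the composite shard by shard, swapping/scaling layers."""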
    ap = argparse.ArgumentParser(
        description="Layer surgery on safetensors: replace and/or rescale layers."
    )
    ap.add_argument("--composite", type=str, required=True)
    ap.add_argument("--base", type=str, help="Donor model dir or HF ID")
    ap.add_argument("--out", type=str, required=True)
    ap.add_argument("--replace_layers", type=str, help='e.g. "61" or "48-55,60,62"')
    ap.add_argument("--map", type=str, default="ratio", choices=["ratio", "wrap"])
    ap.add_argument("--map_pairs", type=str, help='e.g. "61:34,55:30"')
    ap.add_argument("--scale_json", type=str)
    args = ap.parse_args()

    comp_dir = ensure_local(args.composite)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    comp_cfg = read_json(os.path.join(comp_dir, "config.json"))
    L_comp = int(comp_cfg["num_hidden_layers"])
    print(f"Composite layers: {L_comp}")

    replace_set: List[int] = []
    if args.replace_layers:
        if not args.base:
            raise ValueError("--base is required when --replace_layers is set.")
        replace_set = parse_layers(args.replace_layers)
        # Fail loudly instead of silently ignoring out-of-range indices.
        bad = [li for li in replace_set if li < 0 or li >= L_comp]
        if bad:
            raise ValueError(f"--replace_layers indices out of range for composite: {bad}")
        base_dir = ensure_local(args.base)
        base_cfg = read_json(os.path.join(base_dir, "config.json"))
        L_base = int(base_cfg["num_hidden_layers"])
        print(f"Donor layers: {L_base}")
        explicit = build_explicit_map(args.map_pairs)
    else:
        base_dir = ""
        L_base = 0
        explicit = {}

    comp_map, comp_files = index_dir(comp_dir)
    if replace_set:
        base_map, base_files = index_dir(base_dir)
    else:
        base_map, base_files = {}, []

    scales = load_scales(args.scale_json)
    if scales:
        print("Scales loaded for layers:", sorted(scales.keys()))

    # Copy tokenizer/config side files that exist alongside the composite weights.
    to_copy = [
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.model",
        "generation_config.json",
    ]
    for fname in to_copy:
        src = os.path.join(comp_dir, fname)
        if os.path.exists(src):
            shutil.copy2(src, out_dir / fname)

    print("Performing surgery shard-by-shard...")
    out_weight_map: Dict[str, str] = {}
    for comp_f in comp_files:
        rel = os.path.basename(comp_f)
        out_f = out_dir / rel
        new_tensors: Dict[str, torch.Tensor] = {}

        with safe_open(comp_f, framework="pt") as fcomp:
            for k in fcomp.keys():
                li = tensor_layer_idx(k)

                if li is not None and li in replace_set:
                    # Pull this tensor from the mapped donor layer instead.
                    src_li = explicit.get(li, map_layer(li, L_comp, L_base, args.map))
                    src_prefix = layer_prefix(src_li)
                    dst_prefix = layer_prefix(li)
                    donor_k = src_prefix + k[len(dst_prefix):]

                    donor_file = base_map.get(donor_k)
                    if donor_file is None:
                        raise KeyError(f"Donor tensor not found: {donor_k}")
                    donor_path = os.path.join(base_dir, donor_file)
                    # Reopen the donor shard per tensor; safe_open only reads
                    # the requested tensor, not the whole file.
                    with safe_open(donor_path, framework="pt") as fbase:
                        tensor = fbase.get_tensor(donor_k)
                else:
                    tensor = fcomp.get_tensor(k)

                if li is not None:
                    tensor = apply_scales_if_needed(k, tensor, li, scales)

                # save_file requires contiguous tensors.
                if not tensor.is_contiguous():
                    tensor = tensor.contiguous()
                new_tensors[k] = tensor
                out_weight_map[k] = rel

        save_file(new_tensors, str(out_f))

    # Rebuild the index so loaders can locate tensors across the written shards.
    total_size = 0
    for fname in set(out_weight_map.values()):
        fp = out_dir / fname
        if fp.exists():
            total_size += fp.stat().st_size
    index = {
        "metadata": {"total_size": total_size, "format": "safetensors"},
        "weight_map": out_weight_map,
    }
    write_json(out_dir / "model.safetensors.index.json", index)
    print(f"Done. Wrote modified shards and index to: {out_dir}")

    print("\nTip: validate load quickly (meta device):")
    print("  from transformers import AutoModelForCausalLM")
    print(f"  AutoModelForCausalLM.from_pretrained('{out_dir}', device_map='meta', trust_remote_code=True)")


if __name__ == "__main__":
    main()