| """ |
| Helion-OSC Sharded Model Loader |
| Efficiently loads 116 safetensors shards (2.8GB each) |
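
A minimal usage sketch (the path is illustrative):
    loader = ShardedModelLoader("/path/to/inference")
    weights = loader.load_sharded_weights(device="cpu", low_memory=True)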
| """ |

import json
import logging
from pathlib import Path
from typing import Dict, List, Optional

import psutil
import torch
from safetensors.torch import load_file
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ShardedModelLoader:
    """
    Loader for sharded safetensors model files.
    Optimized for 116 shards of ~2.8 GB each.
    """

    def __init__(self, model_path: str):
        """
        Initialize the sharded model loader.

        Args:
            model_path: Path to the inference directory containing the shards
        """
        self.model_path = Path(model_path)
        self.config_path = self.model_path / "model_config.json"
        self.index_path = self.model_path / "model.safetensors.index.json"

        # Load the model configuration
        logger.info(f"Loading configuration from {self.config_path}")
        with open(self.config_path, "r") as f:
            self.config = json.load(f)

        # Load the weight index that maps each tensor name to its shard file
        logger.info(f"Loading weight index from {self.index_path}")
        with open(self.index_path, "r") as f:
            self.index = json.load(f)

        self.metadata = self.index.get("metadata", {})
        self.weight_map = self.index.get("weight_map", {})

        # Use .get() consistently so a sparse config cannot raise KeyError
        arch_info = self.config.get("architectures_info", {})
        logger.info(f"Model: {self.metadata.get('model_type', 'unknown')}")
        logger.info(f"Total shards: {self.metadata.get('total_shards', 0)}")
        logger.info(f"Total size: {self.metadata.get('total_size', 0) / 1e9:.2f} GB")
        logger.info(f"Total parameters: {arch_info.get('total_parameters', 'unknown')}")
        logger.info(f"Active parameters: {arch_info.get('active_parameters', 'unknown')}")

    def get_shard_path(self, shard_name: str) -> Path:
        """Get the full path to a shard file."""
        return self.model_path / shard_name

    def get_available_memory(self) -> Dict[str, float]:
        """Get available system (and, if present, GPU) memory."""
        memory = psutil.virtual_memory()
        result = {
            "ram_total_gb": memory.total / 1e9,
            "ram_available_gb": memory.available / 1e9,
            "ram_percent_used": memory.percent,
        }

        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                # Approximation: memory_allocated() ignores memory reserved by
                # the caching allocator, so "available" may be optimistic
                gpu_mem = torch.cuda.get_device_properties(i).total_memory
                gpu_allocated = torch.cuda.memory_allocated(i)
                result[f"gpu_{i}_total_gb"] = gpu_mem / 1e9
                result[f"gpu_{i}_available_gb"] = (gpu_mem - gpu_allocated) / 1e9

        return result

    def load_shard(self, shard_name: str, device: str = "cpu") -> Dict[str, torch.Tensor]:
        """
        Load a single shard file.

        Args:
            shard_name: Name of the shard file
            device: Device to load the tensors onto

        Returns:
            Dictionary of weight tensors
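
        Example (the shard name is illustrative; real names come from
        the weight map in model.safetensors.index.json):
            shard = loader.load_shard("model-00001-of-00116.safetensors")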
        """
        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.debug(f"Loading shard: {shard_name}")
        return load_file(str(shard_path), device=device)

    def load_sharded_weights(
        self,
        device: str = "cpu",
        low_memory: bool = False,
        show_progress: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Load all sharded weights into a single state dict.

        Args:
            device: Device to load the weights onto
            low_memory: Free per-shard buffers and the CUDA cache after
                each shard
            show_progress: Show a progress bar

        Returns:
            Dictionary of all model weights
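
        Example:
            weights = loader.load_sharded_weights(device="cpu", low_memory=True)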
        """
        logger.info("Loading sharded model weights...")

        # Report available memory before committing to the full load
        mem_info = self.get_available_memory()
        logger.info(f"Available RAM: {mem_info['ram_available_gb']:.2f} GB")
        if "gpu_0_available_gb" in mem_info:
            logger.info(f"Available GPU 0: {mem_info['gpu_0_available_gb']:.2f} GB")

        # Each weight key maps to the shard file that stores it; load each
        # unique shard exactly once
        shard_files = sorted(set(self.weight_map.values()))
        total_shards = len(shard_files)

        logger.info(f"Loading {total_shards} shard files...")

        all_weights = {}

        pbar = tqdm(shard_files, disable=not show_progress, desc="Loading shards")

        for shard_name in pbar:
            pbar.set_description(f"Loading {shard_name}")

            # Merge this shard's tensors into the full state dict
            shard_weights = self.load_shard(shard_name, device=device)
            all_weights.update(shard_weights)

            # In low-memory mode, drop the per-shard reference and flush the
            # CUDA caching allocator after every shard
            if low_memory:
                del shard_weights
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        logger.info(f"Loaded {len(all_weights)} weight tensors")
        return all_weights

    def get_layer_weights(self, layer_idx: int) -> List[str]:
        """
        Get all weight keys for a specific layer.

        Args:
            layer_idx: Layer index

        Returns:
            List of weight keys for that layer
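
        Example:
            keys = loader.get_layer_weights(0)  # keys prefixed "model.layers.0."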
        """
        prefix = f"model.layers.{layer_idx}."
        return [k for k in self.weight_map if k.startswith(prefix)]

    def get_shard_for_weight(self, weight_key: str) -> Optional[str]:
        """
        Get the shard file name that stores a specific weight.

        Args:
            weight_key: Weight key/name

        Returns:
            Shard file name, or None if the key is not in the weight map
        """
        return self.weight_map.get(weight_key)

    def verify_shards(self) -> Dict[str, bool]:
        """
        Verify that every shard file referenced by the weight map exists.

        Returns:
            Dictionary mapping shard names to existence status
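
        Example:
            status = loader.verify_shards()
            missing = [name for name, ok in status.items() if not ok]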
        """
        logger.info("Verifying shard files...")

        shard_files = set(self.weight_map.values())
        verification = {}

        for shard_name in tqdm(sorted(shard_files), desc="Verifying"):
            shard_path = self.get_shard_path(shard_name)
            verification[shard_name] = shard_path.exists()

        missing = [s for s, exists in verification.items() if not exists]

        if missing:
            logger.warning(f"Missing {len(missing)} shard files:")
            for shard in missing[:10]:
                logger.warning(f"  - {shard}")
            if len(missing) > 10:
                logger.warning(f"  ... and {len(missing) - 10} more")
        else:
            logger.info("✓ All shard files present")

        return verification

    def load_metadata(self) -> Dict:
        """Summarize model metadata from the config and weight index."""
        return {
            "config": self.config,
            "index": self.index,
            "total_shards": self.metadata.get("total_shards", 0),
            "total_size_gb": self.metadata.get("total_size", 0) / 1e9,
            "architecture": self.config.get("architectures_info", {}),
            "num_layers": self.config.get("num_hidden_layers", 0),
            "hidden_size": self.config.get("hidden_size", 0),
            "vocab_size": self.config.get("vocab_size", 0),
        }


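# A minimal sketch of lazy, per-tensor access (illustrative; not part of the
# original loader API). Instead of load_file(), which reads a whole ~2.8 GB
# shard into memory, safetensors' safe_open() memory-maps the file and reads
# only the requested tensors, reusing the loader's weight map so each shard
# is opened at most once.
def iter_layer_tensors(loader: ShardedModelLoader, layer_idx: int, device: str = "cpu"):
    """Yield (key, tensor) pairs for one layer without loading whole shards."""
    from safetensors import safe_open

    # Group this layer's weight keys by the shard file that stores them
    keys_by_shard: Dict[str, List[str]] = {}
    for key in loader.get_layer_weights(layer_idx):
        keys_by_shard.setdefault(loader.weight_map[key], []).append(key)

    for shard_name, keys in keys_by_shard.items():
        shard_path = loader.get_shard_path(shard_name)
        with safe_open(str(shard_path), framework="pt", device=device) as f:
            for key in keys:
                yield key, f.get_tensor(key)

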
def load_full_model(
    model_path: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    low_memory: bool = False,
):
    """
    Convenience function to verify and load the full model.

    Args:
        model_path: Path to the inference directory
        device: Device to load the model onto
        low_memory: Use memory-efficient loading

    Returns:
        Tuple of (weights, metadata)
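
    Example (the path is illustrative):
        weights, metadata = load_full_model("/path/to/inference", low_memory=True)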
    """
    loader = ShardedModelLoader(model_path)

    # Refuse to start a ~325 GB load if any shard is missing
    verification = loader.verify_shards()
    missing = sum(1 for exists in verification.values() if not exists)

    if missing > 0:
        raise FileNotFoundError(
            f"Cannot load model: {missing} shard files are missing. "
            f"Please download all 116 shard files."
        )

    weights = loader.load_sharded_weights(
        device=device,
        low_memory=low_memory,
        show_progress=True,
    )

    metadata = loader.load_metadata()

    return weights, metadata


def inspect_model(model_path: str):
    """
    Inspect the model structure without loading any weights.

    Args:
        model_path: Path to the inference directory
    """
    loader = ShardedModelLoader(model_path)

    print("\n" + "=" * 80)
    print("HELION-OSC MODEL INSPECTION")
    print("=" * 80)

    metadata = loader.load_metadata()

    print(f"\nModel Type: {metadata['architecture'].get('model_description', 'N/A')}")
    print(f"Architecture: {metadata['architecture'].get('architecture_type', 'N/A')}")
    print(f"Total Parameters: {metadata['architecture'].get('total_parameters', 'N/A')}")
    print(f"Active Parameters: {metadata['architecture'].get('active_parameters', 'N/A')}")

    print("\nModel Configuration:")
    print(f"  Layers: {metadata['num_layers']}")
    print(f"  Hidden Size: {metadata['hidden_size']}")
    print(f"  Vocabulary Size: {metadata['vocab_size']}")
    print(f"  Attention Heads: {metadata['config'].get('num_attention_heads', 'N/A')}")
    print(f"  KV Heads: {metadata['config'].get('num_key_value_heads', 'N/A')}")

    print("\nMoE Configuration:")
    arch = metadata["architecture"]
    print(f"  Number of Experts: {arch.get('num_experts', 'N/A')}")
    print(f"  Experts per Token: {arch.get('experts_per_token', 'N/A')}")
    print(f"  Shared Experts: {arch.get('num_shared_experts', 'N/A')}")

    print("\nStorage Information:")
    print(f"  Total Shards: {metadata['total_shards']}")
    print(f"  Total Size: {metadata['total_size_gb']:.2f} GB")
    print("  Shard Size: ~2.8 GB each")
    print("  Format: safetensors")
    print("  Precision: bfloat16")

    print("\nContext Length:")
    print(f"  Max Position Embeddings: {metadata['config'].get('max_position_embeddings', 'N/A')}")
    print(f"  RoPE Theta: {metadata['config'].get('rope_theta', 'N/A')}")

    print("\n" + "=" * 80)

    # Check that every shard referenced by the index is actually on disk
    print("\nVerifying shard files...")
    verification = loader.verify_shards()
    present = sum(1 for exists in verification.values() if exists)
    total = len(verification)

    print(f"\nShard Status: {present}/{total} files present")

    if present == total:
        print("✓ All shard files are available")
    else:
        print(f"✗ Missing {total - present} shard files")


def main():
    """Main CLI interface."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC Sharded Model Loader")
    parser.add_argument(
        "model_path",
        type=str,
        help="Path to the inference directory",
    )
    parser.add_argument(
        "--action",
        choices=["inspect", "verify", "load"],
        default="inspect",
        help="Action to perform",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to load the model onto",
    )
    parser.add_argument(
        "--low-memory",
        action="store_true",
        help="Use low-memory mode (free per-shard buffers eagerly)",
    )

    args = parser.parse_args()

    if args.action == "inspect":
        inspect_model(args.model_path)

    elif args.action == "verify":
        loader = ShardedModelLoader(args.model_path)
        loader.verify_shards()

    elif args.action == "load":
        logger.info("Loading full model...")
        weights, metadata = load_full_model(
            args.model_path,
            device=args.device,
            low_memory=args.low_memory,
        )
        logger.info(f"Successfully loaded {len(weights)} weight tensors")
        logger.info(f"Model ready on {args.device}")


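# Example invocations (assuming this file is saved as sharded_loader.py):
#   python sharded_loader.py /path/to/inference --action inspect
#   python sharded_loader.py /path/to/inference --action verify
#   python sharded_loader.py /path/to/inference --action load --low-memory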
if __name__ == "__main__":
    main()