# Extracted from a web file view ("Spaces", status: Sleeping); file size: 2,692 bytes; revision marker eef8873.
import logging
from pathlib import Path
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from .prediction_helper import (
ResnetCarDamagePredictor,
FusionCarDamagePredictor,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Hub coordinates for each model checkpoint, keyed by the short model name.
# Each entry provides the "repo_id" and "filename" passed to hf_hub_download.
MODEL_CONFIG = {
    "resnet": {
        "repo_id": "junaid17/car-damage-classifier",
        "filename": "car-damage-classifier.pt",
    },
    "fusion": {
        "repo_id": "junaid17/best_fusion_model_fp16",
        "filename": "best_fusion_model_fp16.pt",
    },
    "yolo": {
        "repo_id": "junaid17/Yolo_Model",
        "filename": "damage_detector.pt",
    },
}
def get_checkpoint_path(model_key: str) -> Path:
    """Fetch the checkpoint for ``model_key`` and return its local path.

    The key is looked up in ``MODEL_CONFIG`` and the entry's ``repo_id`` and
    ``filename`` are passed to ``hf_hub_download``.

    Args:
        model_key: A key of ``MODEL_CONFIG``.

    Returns:
        The path returned by ``hf_hub_download``, wrapped in a ``Path``.

    Raises:
        ValueError: If ``model_key`` is not a key of ``MODEL_CONFIG``.
        RuntimeError: If anything fails while fetching the file. The
            original exception is attached as ``__cause__``.
    """
    if model_key not in MODEL_CONFIG:
        raise ValueError(f"Unknown model key: {model_key}")

    config = MODEL_CONFIG[model_key]
    try:
        logger.info("Fetching %s model from Hugging Face Hub...", model_key)
        logger.info("Repo: %s", config["repo_id"])
        logger.info("File: %s", config["filename"])
        local_path = hf_hub_download(
            repo_id=config["repo_id"],
            filename=config["filename"],
        )
        logger.info("%s model available at: %s", model_key, local_path)
        return Path(local_path)
    except Exception as e:
        logger.exception("Failed to fetch %s model.", model_key)
        # Chain explicitly so the root cause is reported as the direct cause
        # of the RuntimeError rather than as an incidental context exception.
        raise RuntimeError(f"Failed to load {model_key} checkpoint: {e}") from e
class ModelLoader:
    """Object-style entry point that resolves model keys to checkpoint paths."""

    def __init__(self) -> None:
        """Log that a loader is being created; no instance state is stored."""
        logger.info("Initializing ModelLoader...")

    def get_model_path(self, model_key: str) -> Path:
        """Return the checkpoint path for ``model_key``.

        Delegates directly to ``get_checkpoint_path``, so it returns and
        raises whatever that function does.
        """
        checkpoint_path = get_checkpoint_path(model_key)
        return checkpoint_path
def initialize_models(class_map):
    """Fetch every checkpoint and build the two predictors and the YOLO model.

    All three checkpoint paths are resolved first, then the ResNet predictor,
    the Fusion predictor and the YOLO model are constructed in that order.

    Args:
        class_map: Forwarded unchanged as ``class_map`` to both predictor
            constructors. Its structure is not visible from this module.

    Returns:
        A ``(resnet_predictor, fusion_predictor, yolo_model)`` tuple.

    Raises:
        RuntimeError: If fetching a checkpoint or constructing a model fails.
            The original exception is attached as ``__cause__``.
    """
    logger.info("Starting model initialization...")
    try:
        resnet_path = get_checkpoint_path("resnet")
        fusion_path = get_checkpoint_path("fusion")
        yolo_path = get_checkpoint_path("yolo")

        logger.info("Initializing ResNet predictor...")
        resnet_predictor = ResnetCarDamagePredictor(
            checkpoint_path=resnet_path,
            class_map=class_map,
        )

        logger.info("Initializing Fusion predictor...")
        fusion_predictor = FusionCarDamagePredictor(
            checkpoint_path=fusion_path,
            class_map=class_map,
        )

        logger.info("Initializing YOLO model...")
        # YOLO is given the checkpoint location as a string, not a Path.
        yolo_model = YOLO(str(yolo_path))

        logger.info("All models initialized successfully.")
        return resnet_predictor, fusion_predictor, yolo_model
    except Exception as e:
        logger.exception("Model initialization failed.")
        # Chain explicitly so the underlying failure is kept as the direct
        # cause of the RuntimeError.
        raise RuntimeError(f"Model initialization failed: {e}") from e