File size: 2,692 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging
from pathlib import Path
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

from .prediction_helper import (
    ResnetCarDamagePredictor,
    FusionCarDamagePredictor,
)

logger = logging.getLogger(__name__)

# Where each checkpoint lives on the Hugging Face Hub, keyed by the short
# model name accepted by get_checkpoint_path() ("resnet", "fusion", "yolo").
# Every entry gives the Hub repository ("repo_id") and the file to download
# from it ("filename").
MODEL_CONFIG = dict(
    resnet=dict(
        repo_id="junaid17/car-damage-classifier",
        filename="car-damage-classifier.pt",
    ),
    fusion=dict(
        repo_id="junaid17/best_fusion_model_fp16",
        filename="best_fusion_model_fp16.pt",
    ),
    yolo=dict(
        repo_id="junaid17/Yolo_Model",
        filename="damage_detector.pt",
    ),
)


def get_checkpoint_path(model_key: str) -> Path:
    """Fetch a model checkpoint with ``hf_hub_download`` and return its local path.

    Args:
        model_key: Key into ``MODEL_CONFIG`` (e.g. "resnet", "fusion", "yolo").
            A config entry may optionally include a ``"revision"`` (branch,
            tag or commit hash) to pin the download. When it is absent,
            ``None`` is passed and ``hf_hub_download`` uses its default revision.

    Returns:
        Filesystem path of the checkpoint file returned by ``hf_hub_download``.

    Raises:
        ValueError: If ``model_key`` is not present in ``MODEL_CONFIG``.
        RuntimeError: If reading the config entry or downloading the file
            fails; the original exception is chained as ``__cause__``.
    """
    try:
        config = MODEL_CONFIG[model_key]
    except KeyError:
        # Unknown keys are a caller error, not a download failure.
        raise ValueError(f"Unknown model key: {model_key}") from None

    try:
        logger.info("Fetching %s model from Hugging Face Hub...", model_key)
        logger.info("Repo: %s", config["repo_id"])
        logger.info("File: %s", config["filename"])

        local_path = hf_hub_download(
            repo_id=config["repo_id"],
            filename=config["filename"],
            # Optional pin; None keeps hf_hub_download's default revision.
            revision=config.get("revision"),
        )
    except Exception as e:
        # Boundary catch: surface any config/download failure as RuntimeError
        # while keeping the original error as the explicit cause.
        logger.exception("Failed to fetch %s model.", model_key)
        raise RuntimeError(f"Failed to load {model_key} checkpoint: {e}") from e

    logger.info("%s model available at: %s", model_key, local_path)
    return Path(local_path)


class ModelLoader:
    """Object-style wrapper that delegates checkpoint lookup to ``get_checkpoint_path``."""

    def __init__(self):
        """Create the loader; it keeps no state and only logs that it was created."""
        logger.info("Initializing ModelLoader...")

    def get_model_path(self, model_key: str) -> Path:
        """Return the local checkpoint path for *model_key*.

        Any ``ValueError`` / ``RuntimeError`` raised by ``get_checkpoint_path``
        propagates unchanged.
        """
        checkpoint_path = get_checkpoint_path(model_key)
        return checkpoint_path


def initialize_models(class_map):
    """Download every checkpoint and build the ResNet, Fusion and YOLO models.

    Args:
        class_map: Passed unchanged to both ``ResnetCarDamagePredictor`` and
            ``FusionCarDamagePredictor``. Presumably a mapping from class
            index to damage label — TODO confirm against prediction_helper.

    Returns:
        A ``(resnet_predictor, fusion_predictor, yolo_model)`` tuple.

    Raises:
        RuntimeError: If any checkpoint download or model construction fails;
            the underlying exception is chained as ``__cause__``.
    """
    logger.info("Starting model initialization...")

    try:
        # Fetch all checkpoints first so a download failure is reported
        # before any model object is constructed.
        resnet_path = get_checkpoint_path("resnet")
        fusion_path = get_checkpoint_path("fusion")
        yolo_path = get_checkpoint_path("yolo")

        logger.info("Initializing ResNet predictor...")
        resnet_predictor = ResnetCarDamagePredictor(
            checkpoint_path=resnet_path,
            class_map=class_map,
        )

        logger.info("Initializing Fusion predictor...")
        fusion_predictor = FusionCarDamagePredictor(
            checkpoint_path=fusion_path,
            class_map=class_map,
        )

        logger.info("Initializing YOLO model...")
        # YOLO is given the path as a plain string.
        yolo_model = YOLO(str(yolo_path))
    except Exception as e:
        logger.exception("Model initialization failed.")
        raise RuntimeError(f"Model initialization failed: {e}") from e

    logger.info("All models initialized successfully.")
    return resnet_predictor, fusion_predictor, yolo_model