"""
Contains various utility functions for PyTorch model training and saving.
"""
| | import torch |
| | from pathlib import Path |
| |
|
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str) -> None:
    """Saves a PyTorch model's state_dict to a target directory.

    Only ``model.state_dict()`` is written, not the whole module object.

    Args:
        model: A target PyTorch model to save.
        target_dir: A directory for saving the model to. It is created
            (including any missing parent directories) if it does not exist.
        model_name: A filename for the saved model. Should include
            either ".pth" or ".pt" as the file extension.

    Raises:
        AssertionError: If model_name does not end with ".pth" or ".pt".

    Example usage:
        save_model(model=model_0,
                   target_dir="models",
                   model_name="05_going_modular_tinyvgg_model.pth")
    """
    # Validate the filename before touching the filesystem so a rejected
    # call does not leave a freshly created directory behind.
    # The error is raised explicitly instead of via `assert` so the check is
    # not stripped when Python runs with -O; AssertionError is kept so
    # existing callers that catch it keep working.
    if not model_name.endswith((".pth", ".pt")):
        raise AssertionError("model_name should end with '.pt' or '.pth'")

    # Create the target directory (and parents) if needed.
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name

    # Save only the learned parameters/buffers.
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)