| import ClassUtils |
| import LoadUtils |
|
|
| import torch |
| import torchvision |
| import torchvision.models as models |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import random |
|
|
| import warnings |
|
|
| |
# Silence every DeprecationWarning, whichever module raises it.
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r".*")

# Checkpoint files for the available classifiers.
vgg16_state_path = "VGG16_Full_State_Dict.pth"
mobileNet_path = "MobileNetV3_state_dict_big_train.pth"

# Root folder handed to ClassUtils.CrosswalkDataset.
data_path = "zebra_annotations/classification_data"

# Placeholders; both are rebound further down once a model has been loaded.
classify = None
transform = None
|
|
| |
def load_vgg_classifier(state_dict_path):
    """Build a two-output VGG16 and restore it from a full state dict.

    Args:
        state_dict_path: Path to a checkpoint holding the state dict of the
            whole network (it is loaded into the model itself, not just the
            classifier head).

    Returns:
        The VGG16 model with the checkpoint loaded, switched to eval mode.
    """
    model = models.vgg16()
    # Swap the final classifier layer for a 2-output linear head before
    # loading the checkpoint.
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

    # map_location="cpu" keeps a checkpoint that was saved from a GPU loadable
    # on a host without CUDA; the model built above lives on the CPU anyway.
    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()
    return model
|
|
| |
| |
| |
def partial_vgg_load(classifier_state_dict_path):
    """Build a pretrained VGG16 and restore only its classifier head.

    The network is created with torchvision's default VGG16 weights, its last
    classifier layer is replaced by a 2-output linear layer, and the state
    dict is then loaded into ``model.classifier`` only.

    Args:
        classifier_state_dict_path: Path to a checkpoint containing the state
            dict of the classifier sub-module. An already-loaded state dict
            (mapping) is accepted as well and used directly.

    Returns:
        The VGG16 model in eval mode.
    """
    import os

    model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

    classifier_state = classifier_state_dict_path
    # load_state_dict() needs the mapping itself, not the file name, so read
    # the checkpoint from disk first when a path was given.
    if isinstance(classifier_state, (str, os.PathLike)):
        classifier_state = torch.load(
            classifier_state, map_location="cpu", weights_only=True
        )
    model.classifier.load_state_dict(classifier_state)

    model.eval()
    return model
|
|
| |
def load_resnet_classifier(state_dict_path):
    """Build a single-output ResNet-18 and restore it from a full state dict.

    Args:
        state_dict_path: Path to a checkpoint holding the state dict of the
            whole network.

    Returns:
        The ResNet-18 model with the checkpoint loaded, switched to eval mode.
    """
    # No pretrained weights are requested: the strict load_state_dict() below
    # overwrites every parameter and buffer, so fetching ImageNet weights
    # first (via the deprecated `pretrained=True`) only cost a download.
    resnet = models.resnet18()
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, 1)

    # map_location="cpu" keeps a checkpoint that was saved from a GPU loadable
    # on a host without CUDA.
    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    resnet.load_state_dict(state_dict)

    resnet.eval()
    return resnet
|
|
| |
| |
def load_mobileNet_classifier(state_dict_path):
    """Build a two-output MobileNetV3-Small and restore it from a state dict.

    Args:
        state_dict_path: Path to a checkpoint holding the state dict of the
            whole network.

    Returns:
        The MobileNetV3-Small model with the checkpoint loaded, in eval mode.
    """
    model = models.mobilenet_v3_small()
    # Swap the last classifier layer for a 2-output linear head before
    # loading the checkpoint.
    model.classifier[3] = torch.nn.Linear(model.classifier[3].in_features, 2)

    # map_location="cpu" keeps a checkpoint that was saved from a GPU loadable
    # on a host without CUDA; the model built above lives on the CPU anyway.
    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()
    return model
|
|
| |
| |
|
|
# Active classifier / preprocessing pair. These are loaded at import time and
# become the default arguments of infer(), which is defined below.
classify = load_mobileNet_classifier(mobileNet_path)
transform = ClassUtils.mob3_transform
|
|
|
|
|
|
| |
| |
def infer(image, infer_model=classify, infer_transform=transform):
    """Run the classifier on an image and return sigmoid probabilities.

    Args:
        image: Either a tensor or any object accepted by ``infer_transform``.
            Tensors are used as-is (the transform is skipped). An input with
            three or fewer dimensions gets a leading batch dimension.
        infer_model: Callable model producing logits. Defaults to the
            module-level ``classify`` as it was when this function was defined.
        infer_transform: Callable turning a non-tensor image into a tensor.
            Defaults to the module-level ``transform`` as it was when this
            function was defined.

    Returns:
        A NumPy array holding the element-wise sigmoid of the model output.

    Raises:
        TypeError: If ``infer_model`` or ``infer_transform`` is ``None``.
    """
    if infer_model is None or infer_transform is None:
        raise TypeError("Error: The inference classes have not been initialised properly.")

    if not torch.is_tensor(image):
        image = infer_transform(image)

    # Add the batch dimension when a single image was supplied.
    if len(image.shape) <= 3:
        image = image.unsqueeze(0)

    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        logit_pred = infer_model(image)

    # torch.sigmoid is numerically stable, whereas 1 / (1 + exp(-x)) overflows
    # (with a RuntimeWarning) for large negative logits. .cpu() makes the
    # NumPy conversion work even if the model output lives on another device.
    probs = torch.sigmoid(logit_pred).detach().cpu().numpy()
    return probs
|
|
|
|
| |
| |
def PIL_infer(image, threshold=0.35):
    """Classify a PIL image and return whether it clears ``threshold``.

    Args:
        image: PIL image to classify. Images that are not in RGB mode are
            converted to RGB first.
        threshold: Value that the first output probability of the first
            (only) batch entry must exceed.

    Returns:
        Boolean result of ``probability[0][0] > threshold``.
    """
    # torchvision's stock classification models take 3-channel input, so
    # grayscale, palette and RGBA images are normalised to RGB here instead of
    # failing inside the network's first convolution.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Scale the uint8 pixel values to floats in [0, 1].
    tensor_im = torchvision.transforms.functional.pil_to_tensor(image).float() / 255

    # NOTE(review): infer() skips its transform for tensor inputs, so the
    # module-level `transform` is not applied on this path - confirm that
    # this is intended.
    prediction = infer(tensor_im)
    classification = prediction[0][0] > threshold
    return classification
|
|
| |
def infer_and_display(image, threshold, actual_label, onlyWrong=False):
    """Classify ``image`` and show it with the prediction in the plot title.

    Args:
        image: Image tensor; it is permuted with ``(1, 2, 0)`` for plotting,
            so it must have exactly three dimensions.
        threshold: Value the probabilities are compared against.
        actual_label: Ground-truth label; ``actual_label[0] == 1`` marks a
            positive sample.
        onlyWrong: When true, correctly classified images are not displayed.

    Returns:
        The probability array from ``infer`` when the image is displayed.
        When ``onlyWrong`` suppresses the display, the thresholded boolean
        array is returned instead.
    """
    probability = infer(image)
    prediction = probability > threshold

    # Judge correctness on the first output only, as plain Python bools. The
    # previous element-wise comparison against the whole prediction array has
    # no single truth value when the model has more than one output.
    predicted_positive = bool(prediction[0][0])
    actual_positive = bool(actual_label[0] == 1)
    is_correct = predicted_positive == actual_positive

    if onlyWrong and is_correct:
        # NOTE(review): this path returns the boolean array while the path
        # below returns probabilities; kept as-is for existing callers.
        return prediction

    # Reorder the axes with (1, 2, 0) for imshow.
    plt.imshow(torch.permute(image, (1, 2, 0)).detach().cpu().numpy())
    # The probability is in [0, 1]; format it as a real percentage.
    plt.title(
        f"Prediction: {predicted_positive} with confidence "
        f"{float(probability[0][0]):.1%}, Actual: {actual_positive}"
    )
    plt.axis("off")
    plt.show()

    return probability
|
|
|
|
| |
def example_init(examples=20, display=True):
    """Classify randomly drawn dataset samples and print a tally.

    Args:
        examples: Number of samples to draw (with replacement) from the
            dataset.
        display: When true, each sample is also shown through
            ``infer_and_display`` and its crosswalk probability is printed.
    """
    dataset = ClassUtils.CrosswalkDataset(data_path)

    random_points = [random.randrange(len(dataset)) for _ in range(examples)]

    correct = incorrect = falsepos = falseneg = 0
    for point in random_points:
        image, label = dataset[point]

        # [1, 0] encodes "crosswalk", [0, 1] "no crosswalk".
        predicted_crosswalk = infer(image)[0][0] > 0.5
        class_guess = [1, 0] if predicted_crosswalk else [0, 1]

        if class_guess == label.tolist():
            correct += 1
        else:
            incorrect += 1
            if predicted_crosswalk:
                falsepos += 1
            else:
                falseneg += 1

        if display:
            # NOTE(review): the display uses threshold 0.4 while the tally
            # above uses 0.5 - confirm that the difference is intended.
            shown = infer_and_display(image, 0.4, label)
            # Print the scalar crosswalk probability as a percentage rather
            # than the raw array followed by a '%' sign.
            print(
                f"Prediction of {float(shown[0][0]):.1%} of a crosswalk "
                f"(Crosswalk: {bool(label[0] == 1)})"
            )

    print(f"correct: {correct}, incorrect: {incorrect}, of which false positives were {falsepos} and false negatives were {falseneg}")
|
|
if __name__ != "__main__":
    # Imported as a module: announce that loading has finished.
    print(f"Module: [{__name__}] has been loaded")
else:
    # Executed as a script: tally 200 random samples without plotting.
    example_init(examples=200, display=False)
|
|
|
|
|
|