| |
|
| |
|
from typing import Optional

import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection

from .model import OwlViTForClassification
| |
|
def load_xclip(device: str = "cuda:0",
               n_classes: int = 183,
               use_teacher_logits: bool = False,
               custom_box_head: bool = False,
               model_path: Optional[str] = 'data/models/peeb_pretrain.pt',
               ):
    """Build an ``OwlViTForClassification`` model and its image/text processor.

    The OWL-ViT base detector and processor are fetched with
    ``from_pretrained("google/owlvit-base-patch32")``, the processor's image
    normalization statistics are overridden, and the detector is wrapped in
    ``OwlViTForClassification``. If ``model_path`` is given, the state dict
    stored there is loaded into the wrapper non-strictly.

    Args:
        device: Device string passed to ``.to(...)`` for both the detector
            and the final model.
        n_classes: Forwarded to the wrapper as ``num_classes``.
        use_teacher_logits: Forwarded to the wrapper as
            ``logits_from_teacher``.
        custom_box_head: Forwarded to the wrapper as ``custom_box_head``.
        model_path: Path to a checkpoint readable by ``torch.load``, or
            ``None`` to skip checkpoint loading entirely.

    Returns:
        A ``(model, processor)`` tuple: the ``OwlViTForClassification``
        instance moved to ``device`` and the ``OwlViTProcessor`` with the
        overridden normalization statistics.
    """
    base_checkpoint = "google/owlvit-base-patch32"
    processor = OwlViTProcessor.from_pretrained(base_checkpoint)
    detector = OwlViTForObjectDetection.from_pretrained(base_checkpoint).to(device)

    # Override the processor's per-channel normalization statistics.
    # NOTE(review): presumably computed on the training dataset; the origin
    # of these values is not visible in this file -- confirm.
    processor.image_processor.image_mean = [0.48168647, 0.49244233, 0.42851609]
    processor.image_processor.image_std = [0.18656386, 0.18614962, 0.19659419]

    # Every loss term is given weight 0.
    # NOTE(review): presumably because this loader is not used for training
    # -- confirm against callers.
    weight_dict = {
        "loss_ce": 0,
        "loss_bbox": 0,
        "loss_giou": 0,
        "loss_sym_box_label": 0,
        "loss_xclip": 0,
    }
    model = OwlViTForClassification(
        owlvit_det_model=detector,
        num_classes=n_classes,
        device=device,
        weight_dict=weight_dict,
        logits_from_teacher=use_teacher_logits,
        custom_box_head=custom_box_head,
    )

    if model_path is not None:
        # SECURITY: torch.load unpickles the file, which can execute
        # arbitrary code. Only pass checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location='cpu')
        # strict=False tolerates missing/unexpected keys, so a checkpoint
        # that only partially matches the model loads without raising.
        model.load_state_dict(state_dict, strict=False)

    # Done unconditionally so the returned model is on `device` whether or
    # not a checkpoint was loaded.
    model.to(device)
    return model, processor