| import os |
| import sys |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import einops |
| from PIL import Image |
|
|
# Put the bundled ``deps/gmflow`` directory (located one level above this
# file's directory) at the front of ``sys.path`` so the ``GMFlow`` imports
# that follow can be resolved from it.
_this_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(_this_file))
gmflow_dir = os.path.join(parent_dir, 'deps/gmflow')
sys.path.insert(0, gmflow_dir)
|
|
| from GMFlow.gmflow.gmflow import GMFlow |
| from GMFlow.utils.utils import InputPadder |
|
|
|
|
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of pixel-coordinate grids.

    Args:
        b: Batch size; the same grid is repeated ``b`` times.
        h: Grid height (number of rows).
        w: Grid width (number of columns).
        homogeneous: When true, append a third channel filled with ones.
        device: Optional device on which to create the grid.  ``None``
            uses torch's default device.

    Returns:
        Float tensor of shape ``[b, 2, h, w]`` (``[b, 3, h, w]`` when
        ``homogeneous`` is set).  Channel 0 holds the x (column) index and
        channel 1 holds the y (row) index of every pixel.
    """
    # Create the index vectors directly on the target device instead of
    # building the grid on the default device and copying it afterwards.
    rows = torch.arange(h, device=device)
    cols = torch.arange(w, device=device)

    # 'ij' is the layout this function has always relied on; passing it
    # explicitly avoids the deprecation warning emitted when ``indexing``
    # is omitted.
    y, x = torch.meshgrid(rows, cols, indexing='ij')

    channels = [x, y]
    if homogeneous:
        channels.append(torch.ones_like(x))

    grid = torch.stack(channels, dim=0).float()  # [2 or 3, h, w]
    return grid[None].repeat(b, 1, 1, 1)
|
|
|
|
def bilinear_sample(img,
                    sample_coords,
                    mode='bilinear',
                    padding_mode='zeros',
                    return_mask=False):
    """Sample ``img`` at absolute pixel coordinates via ``F.grid_sample``.

    Args:
        img: Input tensor handed to ``F.grid_sample`` (4-D, ``[B, C, H, W]``,
            since a 4-D sampling grid is built here).
        sample_coords: Pixel coordinates to sample at, ``[B, 2, H, W]`` with
            channel 0 = x and channel 1 = y.  If dimension 1 is not of size
            2 the tensor is treated as channel-last (``[B, H, W, 2]``) and
            permuted.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        padding_mode: Padding mode forwarded to ``F.grid_sample``.
        return_mask: When true, also return a boolean ``[B, H, W]`` mask
            that is true where the normalised coordinate lies in [-1, 1].

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``return_mask`` is
        set.
    """
    if sample_coords.size(1) != 2:
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    _, _, h, w = sample_coords.shape

    # Map pixel coordinates to the [-1, 1] range expected by grid_sample
    # (align_corners=True convention).  A size-1 axis would make the
    # denominator zero and turn the whole grid into NaN/inf, so such an
    # axis is treated as having unit extent: coordinate 0 then maps to -1,
    # which addresses the single available pixel.
    x_grid = 2 * sample_coords[:, 0] / max(w - 1, 1) - 1
    y_grid = 2 * sample_coords[:, 1] / max(h - 1, 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img,
                        grid,
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=True)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (
            y_grid <= 1)
        return img, mask

    return img
|
|
|
|
def flow_warp(feature,
              flow,
              mask=False,
              mode='bilinear',
              padding_mode='zeros'):
    """Warp ``feature`` by displacing every pixel coordinate with ``flow``.

    Args:
        feature: 4-D tensor ``[B, C, H, W]`` to sample from.
        flow: Displacement field whose dimension 1 has size 2; it is added
            to the ``[B, 2, H, W]`` pixel grid of ``feature``.
        mask: Forwarded to ``bilinear_sample`` as ``return_mask``.
        mode: Interpolation mode forwarded to ``bilinear_sample``.
        padding_mode: Padding mode forwarded to ``bilinear_sample``.

    Returns:
        Whatever ``bilinear_sample`` returns: the warped tensor, or a
        ``(warped, valid_mask)`` pair when ``mask`` is true.
    """
    batch, _, height, width = feature.size()
    assert flow.size(1) == 2

    # Absolute sampling position = identity pixel grid + displacement.
    identity = coords_grid(batch, height, width).to(flow.device)
    sample_coords = identity + flow

    return bilinear_sample(feature,
                           sample_coords,
                           mode=mode,
                           padding_mode=padding_mode,
                           return_mask=mask)
|
|
|
|
def forward_backward_consistency_check(fwd_flow,
                                       bwd_flow,
                                       alpha=0.01,
                                       beta=0.5,
                                       return_confidence=False):
    """Compare forward and backward flow to detect inconsistent pixels.

    Each flow is warped by the other one and added to it; for consistent
    pixels the two displacements cancel and the residual norm is small.

    Args:
        fwd_flow: 4-D flow tensor with 2 channels (``[B, 2, H, W]``).
        bwd_flow: 4-D flow tensor with 2 channels (``[B, 2, H, W]``).
        alpha: Scale applied to the summed flow magnitudes in the
            threshold ``alpha * magnitude + beta``.
        beta: Constant offset of that threshold.
        return_confidence: When true, return ``exp(-residual)`` maps
            instead of thresholded masks.

    Returns:
        ``(fwd, bwd)`` tensors of shape ``[B, H, W]``.  By default these are
        float masks that are 1 where the residual exceeds the threshold;
        with ``return_confidence`` they are ``exp(-residual)`` values.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    # Round-trip residuals: flow plus the opposite flow sampled at the
    # displaced location.
    fwd_residual = torch.norm(fwd_flow + flow_warp(bwd_flow, fwd_flow), dim=1)
    bwd_residual = torch.norm(bwd_flow + flow_warp(fwd_flow, bwd_flow), dim=1)

    if return_confidence:
        return torch.exp(-fwd_residual), torch.exp(-bwd_residual)

    # The tolerance grows with the amount of motion at each pixel.
    magnitude = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)
    tolerance = alpha * magnitude + beta

    fwd_occ = (fwd_residual > tolerance).float()
    bwd_occ = (bwd_residual > tolerance).float()
    return fwd_occ, bwd_occ
|
|
|
|
@torch.no_grad()
def get_warped_and_mask(flow_model,
                        image1,
                        image2,
                        image3=None,
                        pixel_consistency=False,
                        return_confidence=False):
    """Estimate bidirectional flow between two images and warp with it.

    Args:
        flow_model: Callable flow network; invoked with
            ``pred_bidir_flow=True`` and expected to return a dict whose
            ``'flow_preds'`` entry is a sequence of predictions.
        image1: First image tensor without a batch dimension (a batch axis
            is added here).  Presumably ``[C, H, W]`` -- confirm with caller.
        image2: Second image tensor, same layout as ``image1``.
        image3: Optional batched tensor to warp with the backward flow.
            Defaults to ``image1`` with a batch axis added (unpadded).
        pixel_consistency: When true, additionally mark pixels whose mean
            absolute difference between ``image2`` and the warped ``image1``
            exceeds ``255 * 0.25``.
        return_confidence: When true, also return the backward
            ``exp(-residual)`` map from the consistency check.

    Returns:
        ``(warped_results, bwd_occ, bwd_flow)`` and, when
        ``return_confidence`` is set, a fourth element ``bwd_err``.
    """
    # Resolve the default before image1 is replaced by its padded version.
    warp_source = image1[None] if image3 is None else image3

    padder = InputPadder(image1.shape, padding_factor=16)
    image1, image2 = padder.pad(image1[None], image2[None])

    prediction = flow_model(image1,
                            image2,
                            attn_splits_list=[2],
                            corr_radius_list=[-1],
                            prop_radius_list=[-1],
                            pred_bidir_flow=True)
    flow_pr = prediction['flow_preds'][-1]

    # Entry 0 is used as the forward flow, entry 1 as the backward flow.
    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)
    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)

    _, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)

    if pixel_consistency:
        # NOTE(review): this branch works on padded tensors and the
        # resulting mask is not unpadded afterwards, so its spatial size may
        # differ from warped_results when the padder actually adds pixels --
        # confirm this is intended.
        padded_bwd_flow = padder.pad(bwd_flow)[0]
        padded_bwd_occ = padder.pad(bwd_occ)[0]
        warped_image1 = flow_warp(image1, padded_bwd_flow)
        colour_error = (image2 - warped_image1).abs().mean(dim=1)
        colour_mismatch = (colour_error > 255 * 0.25).float()
        bwd_occ = torch.clamp(padded_bwd_occ + colour_mismatch, 0, 1)

    warped_results = flow_warp(warp_source, bwd_flow)

    if not return_confidence:
        return warped_results, bwd_occ, bwd_flow

    _, bwd_err = forward_backward_consistency_check(
        fwd_flow, bwd_flow, return_confidence=return_confidence)
    return warped_results, bwd_occ, bwd_flow, bwd_err
|
|
|
|
class FlowCalc():
    """GMFlow wrapper that computes, caches and applies backward flow.

    ``get_flow`` and ``get_mask`` share one inference path and, when a
    ``save_path`` is given, write two files: the backward flow via
    ``np.save(save_path, ...)`` and the occlusion mask as a PNG next to it
    (``save_path`` with its extension replaced by ``.png``).
    """

    def __init__(self,
                 model_path='./weights/gmflow_sintel-0c07dcb3.pth',
                 device='cuda'):
        """Build the GMFlow network and load its weights.

        Args:
            model_path: Checkpoint file.  Either a plain state dict or a
                dict holding the state dict under the ``'model'`` key.
            device: Device the network is moved to (default ``'cuda'``,
                matching the previous hard-coded behaviour).
        """
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)

        # Deserialize onto CPU storages; load_state_dict copies the values
        # into the parameters that already live on the target device.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.model = flow_model

    @torch.no_grad()
    def _estimate(self, image1, image2):
        """Run the network on two images; return ``(bwd_flow, bwd_occ)``.

        Both images are NumPy arrays that are permuted from channel-last to
        channel-first (presumably ``[H, W, 3]`` -- confirm with callers).
        """
        # Feed the inputs to whatever device the network currently lives on.
        device = next(self.model.parameters()).device

        image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
        padder = InputPadder(image1.shape, padding_factor=8)
        image1, image2 = padder.pad(image1[None].to(device),
                                    image2[None].to(device))
        results_dict = self.model(image1,
                                  image2,
                                  attn_splits_list=[2],
                                  corr_radius_list=[-1],
                                  prop_radius_list=[-1],
                                  pred_bidir_flow=True)
        flow_pr = results_dict['flow_preds'][-1]
        fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)
        bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)
        _, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)
        return bwd_flow, bwd_occ

    @staticmethod
    def _save(save_path, bwd_flow, bwd_occ):
        """Write flow and mask to disk; return the mask array that was written."""
        # NOTE: np.save appends '.npy' when save_path lacks that extension,
        # so the cache lookup in get_flow only hits for '.npy' paths.
        np.save(save_path, bwd_flow.cpu().numpy())
        mask_path = os.path.splitext(save_path)[0] + '.png'
        # [1, H, W] -> [H, W, 1], scaled so that flagged pixels become 255.
        mask = bwd_occ.cpu().permute(1, 2, 0).to(torch.long).numpy() * 255
        cv2.imwrite(mask_path, mask)
        return mask

    @torch.no_grad()
    def get_flow(self, image1, image2, save_path=None):
        """Return the backward flow between two images.

        Args:
            image1: First image as a channel-last NumPy array.
            image2: Second image, same layout.
            save_path: Optional cache file.  If it exists, the flow is read
                from it; otherwise flow and mask are computed and saved.

        Returns:
            Backward flow tensor.  A cache hit yields the CPU tensor loaded
            by ``read_flow``; a fresh computation yields a tensor on the
            model's device.
        """
        if save_path is not None and os.path.exists(save_path):
            return read_flow(save_path)

        bwd_flow, bwd_occ = self._estimate(image1, image2)
        if save_path is not None:
            self._save(save_path, bwd_flow, bwd_occ)
        return bwd_flow

    @torch.no_grad()
    def get_mask(self, image1, image2, save_path=None):
        """Return the backward occlusion mask between two images.

        Args:
            image1: First image as a channel-last NumPy array.
            image2: Second image, same layout.
            save_path: Optional cache path.  If the sibling ``.png`` exists
                it is read back; otherwise flow and mask are computed and
                saved.

        Returns:
            The mask.  NOTE(review): the type depends on the path taken --
            a grayscale image from ``read_mask`` on a cache hit, the integer
            ``[H, W, 1]`` array scaled to 0/255 right after saving, or the
            float tensor from the consistency check when ``save_path`` is
            ``None``.  Kept as-is for backward compatibility.
        """
        if save_path is not None:
            mask_path = os.path.splitext(save_path)[0] + '.png'
            if os.path.exists(mask_path):
                return read_mask(mask_path)

        bwd_flow, bwd_occ = self._estimate(image1, image2)
        if save_path is not None:
            bwd_occ = self._save(save_path, bwd_flow, bwd_occ)
        return bwd_occ

    def warp(self, img, flow, mode='bilinear'):
        """Warp a NumPy image with a flow tensor.

        Args:
            img: ``[H, W]`` or ``[H, W, C]`` NumPy array.
            flow: Flow tensor accepted by ``flow_warp`` (dimension 1 of
                size 2), on any device.
            mode: Interpolation mode forwarded to ``flow_warp``.

        Returns:
            Warped NumPy array with the same number of dimensions and the
            same dtype as ``img``.
        """
        expand = False
        if len(img.shape) == 2:
            expand = True
            img = np.expand_dims(img, 2)

        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        dtype = img.dtype
        # grid_sample needs the image and the sampling grid on one device;
        # the flow is on the GPU when it comes straight from get_flow and
        # on the CPU when it was loaded from the cache.
        img = img.to(device=flow.device, dtype=torch.float)
        res = flow_warp(img, flow, mode=mode)
        res = res.to(dtype)
        res = res[0].cpu().permute(1, 2, 0).numpy()
        if expand:
            res = res[:, :, 0]
        return res
|
|
|
|
def read_flow(save_path):
    """Load a flow array saved with ``np.save`` and return it as a tensor.

    Args:
        save_path: Path of the ``.npy`` file to read.

    Returns:
        A CPU tensor created from the loaded array with
        ``torch.from_numpy``.
    """
    return torch.from_numpy(np.load(save_path))
|
|
|
|
def read_mask(save_path):
    """Load the mask image stored next to ``save_path`` as grayscale.

    The file that is read is ``save_path`` with its extension replaced by
    ``.png`` (a path that already ends in ``.png`` is used unchanged).

    Args:
        save_path: Path of the flow file or of the mask image itself.

    Returns:
        The mask converted from BGR to a single-channel grayscale array.

    Raises:
        FileNotFoundError: If the mask file does not exist.
        OSError: If the file exists but OpenCV cannot decode it.
    """
    mask_path = os.path.splitext(save_path)[0] + '.png'

    # cv2.imread signals failure by returning None; without these checks
    # the problem would only surface as an opaque error inside cvtColor.
    if not os.path.isfile(mask_path):
        raise FileNotFoundError(
            'mask file not found: {}'.format(mask_path))

    mask = cv2.imread(mask_path)
    if mask is None:
        raise OSError('could not decode mask image: {}'.format(mask_path))

    return cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
|
|
|
# Shared module-level instance.  NOTE: this runs at import time, so merely
# importing this module constructs the GMFlow network, moves it to the
# default device ('cuda') and loads the default checkpoint from
# './weights/gmflow_sintel-0c07dcb3.pth'.
flow_calc = FlowCalc()
|
|