| | |
| |
|
| | from typing import Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.fft import irfftn, rfftn |
| | from torch.nn.functional import grid_sample, log_softmax, pad |
| |
|
| | from .metrics import angle_error |
| | from .utils import make_grid, rotmat2d |
| | from torchvision.transforms.functional import rotate |
| |
|
class UAVTemplateSamplerFast(torch.nn.Module):
    """Resample a square feature map at a fixed set of in-plane rotations.

    A normalized sampling grid is precomputed once per rotation angle and
    stored as a non-persistent buffer; ``forward`` feeds it to ``grid_sample``.

    Args:
        num_rotations: total number of rotated copies produced by ``forward``.
            Must be a multiple of 4 when ``optimize`` is true.
        w: side length of the (square) sampling grid.
        optimize: when true, only the angles in ``[0, 90)`` are resampled and
            the three remaining quadrants are obtained with ``torch.rot90``.
    """

    def __init__(self, num_rotations, w=128, optimize=True):
        super().__init__()

        # The grid is square: height and width are both `w`.
        h = w
        # Prefer CUDA (as before) but fall back to CPU instead of crashing on
        # hosts without a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        grid_xy = make_grid(
            w=w,
            h=h,
            step_x=1,
            step_y=1,
            orig_y=-h // 2,
            orig_x=-h // 2,
            y_up=True,
        ).to(device)

        if optimize:
            assert (num_rotations % 4) == 0
            angles = torch.arange(
                0, 90, 90 / (num_rotations // 4), device=grid_xy.device
            )
        else:
            # The original referenced an undefined `grid_xz_bev` here, which
            # raised NameError whenever optimize=False.
            angles = torch.arange(
                0, 360, 360 / num_rotations, device=grid_xy.device
            )
        rotmats = rotmat2d(angles / 180 * np.pi)
        grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)

        # Express the rotated coordinates relative to the first grid cell and
        # flip the sign of the second coordinate.
        grid_ij_rot = (grid_xy_rot - grid_xy[..., :1, :1, :]) * grid_xy.new_tensor(
            [1, -1]
        )
        # Map cell indices to the [-1, 1] range used by grid_sample with
        # align_corners=False.
        grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1

        self.optimize = optimize
        self.num_rots = num_rotations
        self.register_buffer("angles", angles, persistent=False)
        self.register_buffer("grid_norm", grid_norm, persistent=False)

    def forward(self, image_bev):
        """Return the rotated copies of ``image_bev``.

        Args:
            image_bev: tensor whose first two dimensions are (batch, channels).

        Returns:
            Tensor of shape ``(b, n, c, h, w)`` where ``h, w`` come from the
            stored grid and ``n`` is the number of stored grids, multiplied by
            4 when ``optimize`` is true.
        """
        grid = self.grid_norm
        b, c = image_bev.shape[:2]
        n, h, w = grid.shape[:3]
        # Pair every batch element with every rotation grid (batch-major).
        grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
        image = (
            image_bev[:, None]
            .repeat_interleave(n, 1)
            .reshape(b * n, *image_bev.shape[1:])
        )

        kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
            b, n, c, h, w
        )

        if self.optimize:
            # Derive the other three quadrants from the first one by exact
            # 90-degree rotations of the sampled kernels.
            kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
            kernels = torch.cat([kernels] + kernels_quad234, 1)

        return kernels
class UAVTemplateSampler(torch.nn.Module):
    """Build rotated copies of a feature map with ``torchvision``'s ``rotate``.

    Args:
        num_rotations: number of rotated copies produced per input.
    """

    def __init__(self, num_rotations):
        super().__init__()

        self.num_rotations = num_rotations

    def Template(self, input_features):
        """Stack ``num_rotations`` rotated versions of ``input_features``.

        Args:
            input_features: tensor whose first dimension is the batch.

        Returns:
            Tensor with a new rotation axis inserted at dimension 1, allocated
            on the device of ``input_features``.
        """
        num_angles = self.num_rotations

        batch, *feature_shape = input_features.shape
        expanded_features = torch.zeros(
            (batch, num_angles, *feature_shape), device=input_features.device
        )

        # Angles run from 360 (inclusive) down towards 0 (exclusive). The
        # count was previously hard-coded to 64, which raised IndexError for
        # fewer rotations and left zero-filled slots for more.
        rotation_angles = torch.linspace(360, 0, num_angles + 1)[:-1].tolist()

        for i, rotation_angle in enumerate(rotation_angles):
            expanded_features[:, i] = rotate(input_features, rotation_angle, fill=0)

        return expanded_features

    def forward(self, image_bev):
        """Return the rotated copies of ``image_bev`` (see ``Template``)."""
        kernels = self.Template(image_bev)

        return kernels
class TemplateSampler(torch.nn.Module):
    """Resample a BEV feature map at a fixed set of in-plane rotations.

    A normalized sampling grid is precomputed once per rotation angle and
    stored as a non-persistent buffer; ``forward`` feeds it to ``grid_sample``.

    Args:
        grid_xz_bev: coordinate grid of the BEV map; its first two dimensions
            are read as ``(h, w)`` and its first cell is used as the origin.
        ppm: scale factor; the sampling grid uses a step of ``1 / ppm``.
        num_rotations: total number of rotated copies produced by ``forward``.
            Must be a multiple of 4 when ``optimize`` is true.
        optimize: when true, only the angles in ``[0, 90)`` are resampled and
            the three remaining quadrants are obtained with ``torch.rot90``.
    """

    def __init__(self, grid_xz_bev, ppm, num_rotations, optimize=True):
        super().__init__()

        step = 1 / ppm
        h, w = grid_xz_bev.shape[:2]
        ksize = max(w, h * 2 + 1)
        radius = ksize * step
        # Keep the grid on the same device as `grid_xz_bev`: the rotation
        # matrices below are created there, and mixing devices would make the
        # einsum and the subtraction fail.
        grid_xy = make_grid(
            radius,
            radius,
            step_x=step,
            step_y=step,
            orig_y=(step - radius) / 2,
            orig_x=(step - radius) / 2,
            y_up=True,
        ).to(grid_xz_bev.device)

        if optimize:
            assert (num_rotations % 4) == 0
            angles = torch.arange(
                0, 90, 90 / (num_rotations // 4), device=grid_xz_bev.device
            )
        else:
            angles = torch.arange(
                0, 360, 360 / num_rotations, device=grid_xz_bev.device
            )
        rotmats = rotmat2d(angles / 180 * np.pi)
        grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)

        # Express the rotated coordinates relative to the first BEV cell, flip
        # the sign of the second coordinate and convert to cell units.
        grid_ij_rot = (grid_xy_rot - grid_xz_bev[..., :1, :1, :]) * grid_xy.new_tensor(
            [1, -1]
        )
        grid_ij_rot = grid_ij_rot / step
        # Map cell indices to the [-1, 1] range used by grid_sample with
        # align_corners=False.
        grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1

        self.optimize = optimize
        self.num_rots = num_rotations
        self.register_buffer("angles", angles, persistent=False)
        self.register_buffer("grid_norm", grid_norm, persistent=False)

    def forward(self, image_bev):
        """Return the rotated copies of ``image_bev``.

        Args:
            image_bev: tensor whose first two dimensions are (batch, channels).

        Returns:
            Tensor of shape ``(b, n, c, h, w)`` where ``h, w`` come from the
            stored grid and ``n`` is the number of stored grids, multiplied by
            4 when ``optimize`` is true.
        """
        grid = self.grid_norm
        b, c = image_bev.shape[:2]
        n, h, w = grid.shape[:3]
        # Pair every batch element with every rotation grid (batch-major).
        grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
        image = (
            image_bev[:, None]
            .repeat_interleave(n, 1)
            .reshape(b * n, *image_bev.shape[1:])
        )
        kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
            b, n, c, h, w
        )

        if self.optimize:
            # Derive the other three quadrants from the first one by exact
            # 90-degree rotations of the sampled kernels.
            kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
            kernels = torch.cat([kernels] + kernels_quad234, 1)

        return kernels
| |
|
| |
|
def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"):
    """Cross-correlate each batch element with its own kernels via the FFT.

    Args:
        signal: tensor of shape ``(b, c, H, W)``.
        kernel: tensor of shape ``(b, d, c, kh, kw)``; channels ``c`` are
            summed over, yielding ``d`` output maps per batch element.
        padding: ``"same"`` pads the signal by half the kernel size on every
            side; otherwise a ``[pad_h, pad_w]`` sequence applied symmetrically.
        padding_mode: mode passed to ``torch.nn.functional.pad`` for the signal.

    Returns:
        Contiguous tensor of shape ``(b, d, H' - kh + 1, W' - kw + 1)`` where
        ``H', W'`` are the padded signal sizes.

    Raises:
        AssertionError: if the padded signal width is odd.
    """
    if padding == "same":
        padding = [size // 2 for size in kernel.shape[-2:]]
    # F.pad takes the last dimension first, with two entries per dimension.
    padding_signal = [p for p in padding[::-1] for _ in range(2)]
    signal = pad(signal, padding_signal, mode=padding_mode)
    assert signal.size(-1) % 2 == 0

    height, width = signal.shape[-2:]
    # Zero-pad the kernel at the far end so it matches the signal size.
    padding_kernel = [0, width - kernel.size(-1), 0, height - kernel.size(-2)]
    kernel_padded = pad(kernel, padding_kernel)

    signal_fr = rfftn(signal, dim=(-1, -2))
    kernel_fr = rfftn(kernel_padded, dim=(-1, -2))

    # Conjugating the kernel spectrum turns the convolution into a
    # cross-correlation (i.e. flips the kernel).
    kernel_fr.imag *= -1
    output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr)
    # The half-spectrum lies along dim -2 (last entry of `dim`); passing the
    # explicit size keeps the inverse transform correct for odd heights too.
    output = irfftn(output_fr, s=(width, height), dim=(-1, -2))

    # Keep only the positions where the kernel fits without wrapping around.
    # Slicing directly avoids deprecated indexing with a list of slices.
    out_h = height - kernel.size(-2) + 1
    out_w = width - kernel.size(-1) + 1
    output = output[..., :out_h, :out_w].contiguous()

    return output
| |
|
| |
|
class SparseMapSampler(torch.nn.Module):
    """Sample a map at rotated and translated copies of a set of 2D points.

    Args:
        num_rotations: number of evenly spaced rotations over 360 degrees.
    """

    def __init__(self, num_rotations):
        super().__init__()
        # The original read `self.conf.num_rotations`, but no `conf` attribute
        # is ever set on this module, so construction always failed.
        angles = torch.arange(0, 360, 360 / num_rotations)
        rotmats = rotmat2d(angles / 180 * np.pi)
        self.num_rotations = num_rotations
        self.register_buffer("rotmats", rotmats, persistent=False)

    def forward(self, image_map, p2d_bev):
        """Sample ``image_map`` for every rotation and every map location.

        Args:
            image_map: tensor ``(B, C, h, w)`` passed to ``grid_sample``.
            p2d_bev: 2D points with the coordinate in the last dimension
                (presumably ``(B, N, 2)`` -- confirm against callers).

        Returns:
            ``(valid, value)``: ``valid`` flags candidates that fall inside the
            map; ``value`` holds the bilinearly sampled features reshaped to
            ``image_map.shape[:2] + valid.shape[-4:]``.
        """
        h, w = image_map.shape[-2:]
        locations = make_grid(w, h, device=p2d_bev.device)
        # Rotate every point by every rotation matrix: (..., K, 2). The
        # equation previously had a trailing comma, i.e. declared a third
        # operand that was never passed.
        p2d_candidates = torch.einsum(
            "kji,...i->...kj", self.rotmats.to(p2d_bev), p2d_bev
        )
        # Translate the rotated points to every map location. The grid needs a
        # singleton rotation axis before its coordinate axis to broadcast
        # against (K, 2); assumes make_grid returns (h, w, 2) -- TODO confirm.
        p2d_candidates = p2d_candidates[..., None, None, :, :] + locations.unsqueeze(-2)

        # Normalize to [-1, 1] following the align_corners=True convention.
        p2d_norm = (p2d_candidates / (image_map.new_tensor([w, h]) - 1)) * 2 - 1
        valid = torch.all((p2d_norm >= -1) & (p2d_norm <= 1), -1)
        value = grid_sample(
            image_map, p2d_norm.flatten(-4, -2), align_corners=True, mode="bilinear"
        )
        # `valid.shape[-4]` is an int and cannot be concatenated to a
        # torch.Size; take the trailing four dimensions instead.
        value = value.reshape(image_map.shape[:2] + valid.shape[-4:])
        return valid, value
| |
|
| |
|
def sample_xyr(volume, xy_grid, angle_grid, nearest_for_inf=False):
    """Interpolate a 5D volume at continuous (x, y, angle) queries.

    Args:
        volume: 5D tensor; its last axis is treated as the periodic angle axis
            and axes ``-3``/``-2`` as the y/x axes.
        xy_grid: query positions with ``(x, y)`` in the last dimension, in
            cell units of the volume.
        angle_grid: query angles in degrees, one per query position.
        nearest_for_inf: when true, valid queries whose trilinear result is
            not finite are replaced by nearest-neighbour samples.

    Returns:
        ``(value, valid)``: sampled values and a boolean mask that is true
        where all three normalized coordinates lie in ``[-1, 1]``.
    """
    # Append one wrapped slice along the angle axis so that interpolation
    # between the last and the first angle bin is possible.
    volume_padded = pad(volume, [0, 1, 0, 0, 0, 0], mode="circular")

    height, width = volume.shape[-3:-1]
    extent = xy_grid.new_tensor((width, height))
    # Dividing by (size - 1) matches the align_corners=True convention.
    xy_unit = xy_grid / (extent - 1)
    angle_unit = (angle_grid / 360) % 1
    # grid_sample reads the last input axis from the first grid coordinate.
    unit_coords = torch.cat((angle_unit[..., None], xy_unit), dim=-1)
    grid_norm = unit_coords * 2 - 1

    in_range = (grid_norm >= -1) & (grid_norm <= 1)
    valid = in_range.all(-1)
    value = grid_sample(volume_padded, grid_norm, align_corners=True, mode="bilinear")

    if not nearest_for_inf:
        return value, valid

    # A single non-finite neighbour poisons the linear interpolation; fall
    # back to the nearest sample for those queries.
    value_nearest = grid_sample(
        volume_padded, grid_norm, align_corners=True, mode="nearest"
    )
    use_nearest = ~torch.isfinite(value) & valid
    return torch.where(use_nearest, value_nearest, value), valid
| |
|
| |
|
def nll_loss_xyr(log_probs, xy, angle):
    """Negative log-likelihood of one (x, y, angle) target per batch element.

    Args:
        log_probs: log-probability volume; a singleton channel axis is
            inserted at dimension 1 before sampling.
        xy: target positions, indexed along the batch in dimension 0.
        angle: target angles, indexed along the batch in dimension 0.

    Returns:
        Flattened tensor holding the negated sampled log-probabilities.
    """
    volume = log_probs.unsqueeze(1)
    # One query per batch element: add three singleton query axes.
    xy_query = xy[:, None, None, None]
    angle_query = angle[:, None, None, None]
    sampled, _ = sample_xyr(volume, xy_query, angle_query)
    return -sampled.reshape(-1)
| |
|
| |
|
def nll_loss_xyr_smoothed(log_probs, xy, angle, sigma_xy, sigma_r, mask=None):
    """Cross-entropy between ``log_probs`` and a Gaussian-smoothed target.

    The target is an unnormalized Gaussian centred on ``(xy, angle)`` that is
    normalized to sum to one over the last three dimensions.

    Args:
        log_probs: tensor ``(..., H, W, R)`` of log-probabilities, with x
            along axis ``-2`` and y along axis ``-3``.
        xy: target positions ``(..., 2)`` as ``(x, y)`` in cell units.
        angle: target angles ``(...)`` in degrees.
        sigma_xy: standard deviation of the positional smoothing, in cells.
        sigma_r: standard deviation of the angular smoothing, in degrees.
        mask: optional boolean ``(..., H, W)`` tensor; cells where it is false
            are excluded from both the target and the log-probabilities.

    Returns:
        Tensor ``(...)`` with one loss value per leading index.
    """
    # Axis -3 is y and axis -2 is x (see the broadcasting of dx/dy below).
    # Unpacking them the other way round only worked for square maps.
    *_, ny, nx, nr = log_probs.shape
    grid_x = torch.arange(nx, device=log_probs.device, dtype=torch.float)
    dx = (grid_x - xy[..., None, 0]) / sigma_xy
    grid_y = torch.arange(ny, device=log_probs.device, dtype=torch.float)
    dy = (grid_y - xy[..., None, 1]) / sigma_xy
    dr = (
        torch.arange(0, 360, 360 / nr, device=log_probs.device, dtype=torch.float)
        - angle[..., None]
    ) % 360
    # Angular distance is periodic: take the shorter way around the circle.
    dr = torch.minimum(dr, 360 - dr) / sigma_r
    diff = (
        dx[..., None, :, None] ** 2
        + dy[..., :, None, None] ** 2
        + dr[..., None, None, :] ** 2
    )
    pdf = torch.exp(-diff / 2)
    if mask is not None:
        pdf.masked_fill_(~mask[..., None], 0)
        # Zero the masked log-probabilities so that 0 * (-inf) cannot
        # produce NaN in the product below.
        log_probs = log_probs.masked_fill(~mask[..., None], 0)
    pdf /= pdf.sum((-1, -2, -3), keepdim=True)
    return -torch.sum(pdf * log_probs.to(torch.float), dim=(-1, -2, -3))
| |
|
| |
|
def log_softmax_spatial(x, dims=3):
    """Apply ``log_softmax`` jointly over the last ``dims`` dimensions of ``x``.

    The trailing dimensions are flattened into one, normalized together, and
    the result is reshaped back to the shape of ``x``.
    """
    flattened = x.flatten(-dims)
    normalized = log_softmax(flattened, dim=-1)
    return normalized.reshape(x.shape)
| |
|
| |
|
@torch.jit.script
def argmax_xy(scores: torch.Tensor) -> torch.Tensor:
    """Return the ``(x, y)`` index of the maximum over the last two dimensions.

    ``x`` indexes the last dimension and ``y`` the second-to-last one; the two
    are stacked along a new trailing dimension.
    """
    flat_scores = scores.flatten(-2)
    best = flat_scores.max(-1).indices
    num_cols = scores.shape[-1]
    # Undo the row-major flattening.
    col = best % num_cols
    row = torch.div(best, num_cols, rounding_mode="floor")
    return torch.stack((col, row), -1)
| |
|
| |
|
@torch.jit.script
def expectation_xy(prob: torch.Tensor) -> torch.Tensor:
    """Return the expected grid coordinate under ``prob``.

    ``prob`` is weighted against the grid returned by ``make_grid`` for the
    last two dimensions ``(h, w)`` and summed over those dimensions.
    """
    height = prob.shape[-2]
    width = prob.shape[-1]
    coords = make_grid(float(width), float(height), device=prob.device)
    coords = coords.to(prob)
    return torch.einsum("...hw,hwd->...d", prob, coords)
| |
|
| |
|
@torch.jit.script
def expectation_xyr(
    prob: torch.Tensor, covariance: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return the expected ``(x, y, angle)`` under ``prob``.

    Args:
        prob: tensor ``(..., h, w, num_rotations)`` of probabilities.
        covariance: when true, also return the covariance of the position.

    Returns:
        ``(xyr_mean, xy_cov)``: the mean ``(x, y, angle)`` with the angle in
        degrees in ``[0, 360)``, and the ``(..., 2, 2)`` positional covariance
        or ``None`` when ``covariance`` is false.
    """
    h = prob.shape[-3]
    w = prob.shape[-2]
    num_rotations = prob.shape[-1]
    cols, rows = torch.meshgrid(
        [
            torch.arange(w, device=prob.device, dtype=prob.dtype),
            torch.arange(h, device=prob.device, dtype=prob.dtype),
        ],
        indexing="xy",
    )
    grid_xy = torch.stack((cols, rows), -1)
    xy_mean = torch.einsum("...hwn,hwd->...d", prob, grid_xy)

    # Average the angle on the unit circle so that it wraps correctly.
    turns = torch.arange(0, 1, 1 / num_rotations, device=prob.device, dtype=prob.dtype)
    radians = turns * 2 * np.pi
    unit_vectors = torch.stack([torch.cos(radians), torch.sin(radians)], -1)
    mean_vector = torch.einsum("...hwn,nd->...d", prob, unit_vectors)
    angle = torch.atan2(mean_vector[..., 1], mean_vector[..., 0])
    angle = (angle * 180 / np.pi) % 360

    xy_cov: Optional[torch.Tensor] = None
    if covariance:
        # Second moment minus the outer product of the mean.
        second_moment = torch.einsum(
            "...hwn,...hwd,...hwk->...dk", prob, grid_xy, grid_xy
        )
        xy_cov = second_moment - torch.einsum("...d,...k->...dk", xy_mean, xy_mean)

    xyr_mean = torch.cat((xy_mean, angle.unsqueeze(-1)), -1)
    return xyr_mean, xy_cov
| |
|
| |
|
@torch.jit.script
def argmax_xyr(scores: torch.Tensor) -> torch.Tensor:
    """Return ``(x, y, angle)`` of the maximum over the last three dimensions.

    The last three dimensions of ``scores`` are read as
    ``(height, width, num_rotations)``; the angle is the rotation index
    converted to degrees.
    """
    flat_index = scores.flatten(-3).max(-1).indices
    width = scores.shape[-2]
    num_rotations = scores.shape[-1]
    cells_per_row = width * num_rotations
    # Undo the row-major flattening, outermost axis first.
    row = torch.div(flat_index, cells_per_row, rounding_mode="floor")
    within_row = flat_index % cells_per_row
    col = torch.div(within_row, num_rotations, rounding_mode="floor")
    rotation_index = flat_index % num_rotations
    angle = rotation_index * 360 / num_rotations
    return torch.stack((col, row, angle), -1)
| |
|
| |
|
@torch.jit.script
def mask_yaw_prior(
    scores: torch.Tensor, yaw_prior: torch.Tensor, num_rotations: int
) -> torch.Tensor:
    """Set scores of rotations outside a yaw prior to ``-inf``, in place.

    Args:
        scores: score tensor with the rotation axis last; modified in place.
        yaw_prior: tensor whose last dimension is split into two halves, the
            prior yaw and the tolerated range around it.
        num_rotations: number of rotation bins covering 360 degrees.

    Returns:
        The same ``scores`` tensor after masking.
    """
    bin_width = 360 / num_rotations
    half_bin = bin_width / 2
    # Angles start half a bin above zero and advance one bin at a time.
    bin_angles = torch.arange(
        half_bin, 360 + half_bin, bin_width, device=scores.device
    )
    yaw_init, yaw_range = yaw_prior.chunk(2, dim=-1)
    within_prior = angle_error(bin_angles, yaw_init) < yaw_range
    outside_prior = ~within_prior[:, None, None]
    return scores.masked_fill_(outside_prior, -np.inf)
| |
|
| |
|
def fuse_gps(log_prob, uv_gps, ppm, sigma=10, gaussian=False):
    """Combine a log-probability volume with a positional prior.

    Args:
        log_prob: log-probability volume whose axes ``-3``/``-2`` are spatial
            and whose last axis is left untouched by the prior.
        uv_gps: prior position, broadcast against the grid from ``make_grid``.
        ppm: scale that converts ``sigma`` into grid units.
        sigma: size of the prior before scaling by ``ppm``.
        gaussian: when true use a Gaussian log-density; otherwise a hard disc
            of radius ``sigma * ppm`` outside of which the prior is ``-inf``.

    Returns:
        The fused volume, renormalized with ``log_softmax_spatial``.
    """
    width_height = log_prob.shape[-3:-1][::-1]
    grid = make_grid(*width_height).to(log_prob)
    offsets = grid - uv_gps
    dist = (offsets**2).sum(-1)
    sigma_pixel = sigma * ppm
    if gaussian:
        gps_log_prob = -0.5 * dist / sigma_pixel**2
    else:
        inside = dist < sigma_pixel**2
        gps_log_prob = torch.where(inside, 1, -np.inf)
    # The prior has no rotation axis: broadcast it over the last dimension.
    fused = log_prob + gps_log_prob.unsqueeze(-1)
    return log_softmax_spatial(fused)
| |
|