| import torch |
| import einops |
| import torchvision |
| import torch.nn as nn |
| from typing import List, Tuple |
|
|
|
|
class MultiviewStack(nn.Module):
    """Encode each view of a multi-view image tensor with its own encoder.

    The input to ``forward`` has shape ``(..., V, C, H, W)``. View ``i`` is
    normalized with ``normalizations[i]`` and passed through ``encoders[i]``.
    The per-view results are stacked along a view axis and the leading
    dimensions of the input are restored, giving ``(..., V, F)`` where ``F``
    is the flattened per-view encoder output.

    Args:
        encoders: One encoder per view, indexed in view order.
        normalizations: One ``(mean, std)`` pair per view, passed to
            ``torchvision.transforms.Normalize``.
        output_dim: Stored as ``self.output_dim``; not otherwise used by
            this class.

    Raises:
        ValueError: If fewer normalizations than encoders are supplied.
    """

    def __init__(
        self,
        encoders: List[nn.Module],
        normalizations: List[Tuple[List, List]],
        output_dim: int,
    ):
        super().__init__()
        self.encoders = nn.ModuleList(encoders)
        # Kept as a plain Python list (as before), so these transforms are
        # not registered as submodules of this module.
        self.normalizations = [
            torchvision.transforms.Normalize(mean=mean, std=std)
            for mean, std in normalizations
        ]
        # Fail at construction instead of with an IndexError inside forward().
        if len(self.normalizations) < len(self.encoders):
            raise ValueError(
                f"Got {len(self.normalizations)} normalizations for "
                f"{len(self.encoders)} encoders; need one per encoder."
            )
        self.output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode every view of ``x`` and return the stacked features.

        Args:
            x: Tensor of shape ``(..., V, C, H, W)`` where ``V`` equals the
                number of encoders.

        Returns:
            Tensor of shape ``(..., V, F)``; ``out[..., i, :]`` holds the
            flattened output of ``encoders[i]`` for view ``i``.
            Assumes each encoder returns a tensor whose first axis is the
            flattened batch -- TODO confirm against the encoders in use.

        Raises:
            ValueError: If the number of views in ``x`` differs from the
                number of encoders.
        """
        orig_shape = x.shape
        # Collapse all leading dimensions into one batch axis.
        x = einops.rearrange(x, "... V C H W -> (...) V C H W")

        num_views = x.shape[1]
        if num_views != len(self.encoders):
            raise ValueError(
                f"Input has {num_views} views but {len(self.encoders)} "
                "encoders were configured."
            )

        outputs = [
            encoder(self.normalizations[i](x[:, i]))
            for i, encoder in enumerate(self.encoders)
        ]

        # Stack on axis 1 so the layout is (batch, V, features...). Stacking
        # on the last axis would put the view axis after the features, and
        # the reshape below would then interleave features across views.
        out = torch.stack(outputs, dim=1)
        return out.reshape(*orig_shape[:-3], -1)
|
|