| | import gc |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from einops import repeat, rearrange |
| | from vidtome import merge |
| | from utils.flow_utils import flow_warp, coords_grid |
| |
|
| | |
| |
|
| |
|
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of ``feat``.

    Args:
        feat: 4-D tensor indexed as ``(N, C, H, W)``. It does not need to be
            contiguous (transposed or strided views are accepted).
        eps: Added to the variance before taking the square root, so the
            result is never exactly zero.

    Returns:
        ``(mean, std)``, each shaped ``(N, C, 1, 1)`` so they broadcast
        against ``feat``. The variance is torch's default (unbiased) estimator.

    Raises:
        ValueError: if ``feat`` is not 4-D.
    """
    if feat.dim() != 4:
        raise ValueError(
            f"calc_mean_std expects a 4-D tensor, got shape {tuple(feat.shape)}")
    n, c = feat.shape[:2]
    # reshape (not view) so non-contiguous inputs also work; flatten once.
    flat = feat.reshape(n, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    feat_mean = flat.mean(dim=2).view(n, c, 1, 1)
    return feat_mean, feat_std
| |
|
| |
|
class AttentionControl:
    """Per-run controller shared by attention layers and the latent update loop.

    * ``forward``/``__call__`` caches self-attention contexts of 'up' layers and
      can substitute previously cached ones.
    * ``update_x0`` warps predicted latents across frames with stored flows.
    * ``merge_x0``/``merge_x0_scores`` merge similar latent tokens across frames.

    Period tuples (``warp_period``, ``merge_period``, ``cross_period``, ...) are
    fractions of ``total_step``. Whether each bound is inclusive is defined by
    the method that reads the tuple.
    """

    def __init__(self,
                 warp_period=(0.0, 0.0),
                 merge_period=(0.0, 0.0),
                 merge_ratio=(0.3, 0.3),
                 ToMe_period=(0.0, 1.0),
                 mask_period=(0.0, 0.0),
                 cross_period=(0.0, 0.0),
                 ada_period=(0.0, 0.0),
                 inner_strength=1.0,
                 loose_cfatnn=False,
                 flow_merge=True,
                 ):
        self.cur_frame_idx = 0

        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        # Index of the next cached context in step_store['first'/'previous'];
        # reset by set_total_step() and set_task().
        self.cur_index = 0
        # Task flags, normally set through set_task().
        self.init_store = False
        self.restore = False
        self.update = False
        self.flow = None
        self.mask = None
        self.cldm = None
        self.decoded_imgs = []
        self.restorex0 = True
        self.updatex0 = False
        # Several of the following are not read by this class; presumably
        # they are consumed by callers — verify before removing any.
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period
        self.ToMe_period = ToMe_period
        self.merge_period = merge_period
        self.merge_ratio = merge_ratio
        self.keyframe_idx = 0
        self.flow_merge = flow_merge
        self.distances = {}            # set_distance() cache, keyed by latent height
        self.flow_correspondence = {}  # set_flow_correspondence() cache, keyed by latent height
        # (h, w) fraction of the latent grid that is real content; the rest
        # (bottom rows / right columns) is treated as padding by the merges.
        self.non_pad_ratio = (1.0, 1.0)
        # forward() only touches contexts with shape[2] < up_resolution, so
        # loose_cfatnn excludes contexts whose shape[2] is exactly 1280.
        self.up_resolution = 1280 if loose_cfatnn else 1281

    @staticmethod
    def get_empty_store():
        """Return a fresh, empty ``step_store`` dictionary."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': [],
            'pre_x0': [],
            "pre_keyframe_lq": None,
            "flows": None,
            "occ_masks": None,
            "flow_confids": None,
            "merge": None,
            "unmerge": None,
            "corres_scores": None,
            "flows2": None,
            "flow_confids2": None,
        }

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Cache and/or substitute self-attention contexts of 'up' layers.

        Only calls with ``is_cross`` False, ``place_in_unet == 'up'`` and
        ``context.shape[2] < up_resolution`` are affected. Every such call
        advances ``cur_index``. Depending on the task flags:

        * ``init_store``: append a detached copy to ``'first'`` and ``'previous'``.
        * ``restore`` (while ``cur_step`` is within ``cross_period``, both ends
          inclusive): return a clone of the cached ``'previous'`` entry.
        * ``update``: overwrite the cached ``'previous'`` entry with this call's
          incoming context.

        Returns the (possibly substituted) context.
        """
        cross_lo = self.total_step * self.cross_period[0]
        cross_hi = self.total_step * self.cross_period[1]
        if is_cross or place_in_unet != 'up' \
                or context.shape[2] >= self.up_resolution:
            return context

        store = self.step_store
        if self.init_store:
            store['first'].append(context.detach())
            store['previous'].append(context.detach())
        if self.update:
            # Snapshot before a possible substitution below.
            fresh = context.clone().detach()
        if self.restore and cross_lo <= self.cur_step <= cross_hi:
            context = store['previous'][self.cur_index].clone()
        if self.update:
            store['previous'][self.cur_index] = fresh
        self.cur_index += 1
        return context

    def update_x0(self, x0, cur_frame=0):
        """Warp predicted latents ``x0`` of a frame batch with the stored flows.

        Active only when ``restorex0`` is set and ``cur_step`` lies in
        ``[total_step * warp_period[0], int(total_step * warp_period[1]))``.
        ``mid = x0.shape[0] // 2`` is used as the reference frame. Inside
        the window:

        1. If ``pre_x0`` already holds one latent per warp step (stored by
           earlier calls), ``x0[mid]`` is blended with the flow-warped
           stored latent for this step, weighted by ``occ_masks[mid]``.
        2. Every other frame is blended with the flow-warped ``x0[mid]``,
           weighted by its own ``occ_masks`` entry.
        3. ``x0[mid]`` is stored for this step.

        ``pre_x0`` slots are indexed relative to the first warp step, so a
        window that does not start at step 0 stays aligned.

        ``x0`` is modified in place and returned. ``cur_frame`` is unused and
        kept for call compatibility.
        """
        import math

        if not self.restorex0:
            return x0
        warp_lo = self.total_step * self.warp_period[0]
        warp_hi = int(self.total_step * self.warp_period[1])
        if not (self.cur_step >= warp_lo and self.cur_step < warp_hi):
            return x0

        # First integer step satisfying cur_step >= warp_lo.
        first_warp_step = math.ceil(warp_lo)
        num_slots = warp_hi - first_warp_step
        slot = self.cur_step - first_warp_step

        store = self.step_store
        flows = store["flows"]
        occ = store["occ_masks"]
        mid = x0.shape[0] // 2
        if len(store["pre_x0"]) == num_slots:
            print(f"[INFO] keyframe latent warping @ step {self.cur_step}...")
            warped = flow_warp(store["pre_x0"][slot][None], flows[mid],
                               mode='nearest')[0]
            x0[mid] = (1 - occ[mid]) * x0[mid] + warped * occ[mid]

        print(f"[INFO] local latent warping @ step {self.cur_step}...")
        for i in range(x0.shape[0]):
            if i == mid:
                continue
            warped = flow_warp(x0[mid][None], flows[i], mode='nearest')[0]
            x0[i] = (1 - occ[i]) * x0[i] + warped * occ[i]

        if len(store["pre_x0"]) < num_slots:
            store['pre_x0'].append(x0[mid])
        else:
            store['pre_x0'][slot] = x0[mid]
        return x0

    def _non_pad_layout(self, H, W, device):
        """Describe the non-padded part of an ``H x W`` latent grid.

        Padding is the bottom rows / right columns beyond ``non_pad_ratio``.
        Returns ``(non_pad_idx, non_pad_h, non_pad_w)``. ``non_pad_idx`` is an
        int64 tensor of shape ``(1, n, 1)`` holding the flat ``h * W + w``
        positions of the non-padded cells, in row-major order.
        """
        ratio_h, ratio_w = self.non_pad_ratio
        pad_w = W - int(W * ratio_w)
        pad_h = H - int(H * ratio_h)
        padding_mask = torch.zeros((H, W), device=device, dtype=torch.bool)
        if pad_w:
            padding_mask[:, -pad_w:] = True
        if pad_h:
            padding_mask[-pad_h:, :] = True
        flat_pos = torch.arange(H * W, device=device, dtype=torch.int64)
        non_pad_idx = flat_pos[None, ~padding_mask.reshape(-1), None]
        return non_pad_idx, H - pad_h, W - pad_w

    def merge_x0(self, x0, merge_ratio):
        """Merge similar latent tokens across frames with vidtome matching.

        Active only while ``cur_step`` lies in
        ``[total_step * merge_period[0], int(total_step * merge_period[1]))``.
        Otherwise ``x0`` is returned unchanged.

        Non-padded tokens of all frames are concatenated. The merge/unmerge
        pair from ``merge.bipartite_soft_matching_randframe`` is applied to
        them, with flows cropped to the non-padded area. The result is
        written back, and padded tokens are left as they were.

        Args:
            x0: latent batch indexed as ``(B, C, H, W)``.
            merge_ratio: forwarded to ``bipartite_soft_matching_randframe``.
        """
        if not (self.cur_step >= self.total_step * self.merge_period[0]
                and self.cur_step < int(self.total_step * self.merge_period[1])):
            return x0
        print(f"[INFO] latent merging @ step {self.cur_step}...")

        B, C, H, W = x0.shape
        non_pad_idx, non_pad_h, non_pad_w = self._non_pad_layout(H, W, x0.device)
        gather_idx = non_pad_idx.expand(B, -1, C)

        x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H)
        x_non_pad = torch.gather(x0, dim=1, index=gather_idx)

        # Crop each flow to the non-padded region. Cloning only the crops
        # keeps the stored flows isolated from the matcher (as a deepcopy
        # would) without duplicating the full-size flows first.
        flows = [
            None if flow is None else flow[:, :, :non_pad_h, :non_pad_w].clone()
            for flow in self.step_store["flows"]
        ]

        x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c')
        m, u, _ = merge.bipartite_soft_matching_randframe(
            x_non_pad, B, merge_ratio, 0, target_stride=B,
            H=H,
            flow=flows,
            flow_confid=self.step_store["flow_confids"],
        )
        x_non_pad = u(m(x_non_pad))
        x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', b=B)

        x0.scatter_(dim=1, index=gather_idx, src=x_non_pad)
        return rearrange(x0, 'b (h w) c -> b c h w ', h=H)

    def merge_x0_scores(self, x0, merge_ratio, merge_mode="replace"):
        """Merge non-middle-frame tokens into the middle frame using stored scores.

        Active in the same window as :meth:`merge_x0`. Otherwise ``x0`` is
        returned unchanged.

        Tokens of frame ``B // 2`` are destinations; all other frames'
        non-padded tokens are sources. Scores are the L2-normalised
        ``step_store['corres_scores']``. Entries along the flow
        correspondences cached for height ``H`` are boosted by 0.3 times a
        shifted flow confidence. (Scores are assumed to be
        ``(1, n_src, n_dst)`` — TODO confirm against the producer.)

        The ``int(n_src * merge_ratio)`` sources with the highest best-match
        score are replaced by their best-matching destination token. If
        ``merge_mode`` is not ``"replace"``, it is first passed as ``reduce``
        to ``Tensor.scatter_reduce``, which folds those sources into their
        destinations.
        """
        if not (self.cur_step >= self.total_step * self.merge_period[0]
                and self.cur_step < int(self.total_step * self.merge_period[1])):
            return x0
        print(f"[INFO] latent merging @ step {self.cur_step}...")

        B, C, H, W = x0.shape
        non_pad_idx, _, _ = self._non_pad_layout(H, W, x0.device)
        x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H)
        x_non_pad = torch.gather(x0, dim=1, index=non_pad_idx.expand(B, -1, C))
        tokens_per_frame = x_non_pad.shape[1]
        num_tokens = tokens_per_frame * B
        mid = B // 2
        x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c')

        # Split the concatenated tokens: destinations = middle frame,
        # sources = every other frame (order preserved).
        token_ids = torch.arange(num_tokens, device=x0.device, dtype=torch.int64)
        frame_of_token = torch.div(token_ids, tokens_per_frame,
                                   rounding_mode='floor') % B
        dst_select = frame_of_token == mid
        a_idx = token_ids[None, ~dst_select, None]
        b_idx = token_ids[None, dst_select, None]
        del token_ids, frame_of_token
        num_dst = b_idx.shape[1]
        src = torch.gather(x_non_pad, dim=1,
                           index=a_idx.expand(1, num_tokens - num_dst, C))
        tar = torch.gather(x_non_pad, dim=1, index=b_idx.expand(1, num_dst, C))

        # Boost scores along the cached flow correspondences.
        flow_src_idx, flow_tar_idx = self.flow_correspondence[H]
        confids = self.step_store["flow_confids"]
        flow_confid = torch.cat(confids[:mid] + confids[mid + 1:], dim=0)
        flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)')
        scores = F.normalize(self.step_store["corres_scores"], p=2, dim=-1)
        # Shift confidences so their maximum equals the maximum score.
        flow_confid = flow_confid - (torch.max(flow_confid) - torch.max(scores))
        src_rows = flow_src_idx[0, :, 0]
        scores[:, src_rows, flow_tar_idx[0, :, 0]] += flow_confid[:, src_rows] * 0.3

        # Pick the r most confident source tokens to merge.
        r = min(src.shape[1], int(src.shape[1] * merge_ratio))
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]
        src_idx = edge_idx[..., :r, :]
        tar_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx)
        unm = torch.gather(src, dim=-2, index=unm_idx.expand(-1, -1, C))
        if merge_mode != "replace":
            src = torch.gather(src, dim=-2, index=src_idx.expand(-1, -1, C))
            tar = tar.scatter_reduce(-2, tar_idx.expand(-1, -1, C),
                                     src, reduce=merge_mode, include_self=True)
        # Merged sources take the value of their matched destination.
        src = torch.gather(tar, dim=-2, index=tar_idx.expand(-1, -1, C))

        # Write destinations, unmerged sources and merged sources back.
        x_non_pad.scatter_(dim=-2, index=b_idx.expand(1, -1, C), src=tar)
        unm_pos = torch.gather(a_idx, dim=1, index=unm_idx)
        x_non_pad.scatter_(dim=-2, index=unm_pos.expand(-1, -1, C), src=unm)
        src_pos = torch.gather(a_idx, dim=1, index=src_idx)
        x_non_pad.scatter_(dim=-2, index=src_pos.expand(-1, -1, C), src=src)

        x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', a=tokens_per_frame)
        x0.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad)
        return rearrange(x0, 'b (h w) c -> b c h w ', h=H)

    def set_distance(self, B, H, W, radius, device):
        """Cache a spatial proximity kernel for an ``H x W`` grid in ``distances[H]``.

        The kernel is ``exp(-floor(d / radius))``, where ``d`` is the Euclidean
        distance between every pair of grid cells. A ``radius`` of 0 is
        treated as 1. The ``(H*W, H*W)`` matrix is tiled ``B`` times along
        rows, giving shape ``(1, B*H*W, H*W)``.
        """
        ys, xs = torch.meshgrid(torch.arange(H, device=device),
                                torch.arange(W, device=device), indexing="ij")
        coords = torch.stack((ys, xs), dim=-1).float().reshape(-1, 2)

        distances = torch.cdist(coords, coords)
        radius = 1 if radius == 0 else radius
        distances //= radius
        distances = torch.exp(-distances)

        distances = repeat(distances, 'h a -> 1 (b h) a', b=B)
        self.distances[H] = distances

    def set_flow_correspondence(self, B, H, W, key_idx, flow_confid, flow):
        """Cache flow-based pixel correspondences for latent height ``H``.

        ``flow_confid`` and ``flow`` are per-frame lists. If ``flow`` does not
        already have ``B - 1`` entries, the ``key_idx`` entry is dropped from
        both lists.

        Pixels of the remaining frames are ranked by descending confidence.
        Each one is moved by its flow (truncated to int and clamped inside the
        frame). Stores ``(src_idx, tar_idx)`` in ``flow_correspondence[H]``,
        both shaped ``(1, (B-1)*H*W, 1)``:

        * ``src_idx`` indexes the concatenated frames' pixels.
        * ``tar_idx`` is the flat ``y * W + x`` target pixel, presumably in the
          key frame — confirm against the flow producer.
        """
        if len(flow) != B - 1:
            flow_confid = flow_confid[:key_idx] + flow_confid[key_idx + 1:]
            flow = flow[:key_idx] + flow[key_idx + 1:]

        flow_confid = rearrange(torch.cat(flow_confid, dim=0), 'b h w -> 1 (b h w)')
        flow = torch.cat(flow, dim=0)

        src_idx = flow_confid.argsort(dim=-1, descending=True)[..., None]
        ranked = src_idx[0, :, 0]
        area = H * W
        frame = ranked // area
        pixel = ranked % area
        src_x = pixel % W
        src_y = pixel // W

        # Assumes coords_grid yields per-pixel (x, y) coordinates shaped like
        # flow, i.e. (B-1, 2, H, W) — TODO confirm.
        grid = coords_grid(B - 1, H, W).to(flow.device) + flow
        tar_x = grid[frame, 0, src_y, src_x].clamp(0, W - 1).long()
        tar_y = grid[frame, 1, src_y, src_x].clamp(0, H - 1).long()
        tar_idx = rearrange(tar_y * W + tar_x, ' d -> 1 d 1')

        self.flow_correspondence[H] = (src_idx, tar_idx)

    def set_merge(self, merge, unmerge):
        """Store merge/unmerge callables in ``step_store``."""
        self.step_store["merge"] = merge
        self.step_store["unmerge"] = unmerge

    def set_warp(self, flows, masks, flow_confids=None):
        """Store per-frame flows and occlusion masks, plus confidences if given."""
        self.step_store["flows"] = flows
        self.step_store["occ_masks"] = masks
        if flow_confids is not None:
            self.step_store["flow_confids"] = flow_confids

    def set_warp2(self, flows, flow_confids):
        """Store a second set of flows and confidences."""
        self.step_store["flows2"] = flows
        self.step_store["flow_confids2"] = flow_confids

    def set_pre_keyframe_lq(self, pre_keyframe_lq):
        """Store ``pre_keyframe_lq`` in ``step_store``."""
        self.step_store["pre_keyframe_lq"] = pre_keyframe_lq

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Alias for :meth:`forward`."""
        context = self.forward(context, is_cross, place_in_unet)
        return context

    def set_cur_frame_idx(self, frame_idx):
        self.cur_frame_idx = frame_idx

    def set_step(self, step):
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the number of steps and restart the context-cache cursor."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop all cached state and replace it with a fresh empty store."""
        del self.step_store
        # Collect first so tensors only reachable through cycles are freed
        # before the CUDA caching allocator is asked to release memory.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Configure the task flags from substrings of ``task``.

        Every flag is reset first, and ``cur_index`` is reset to 0. Then:

        * ``'initfirst'``: enables ``init_store`` and clears the store.
        * ``'updatestyle'``: enables ``update``.
        * ``'keepstyle'``: enables ``restore``.
        * ``'updatex0'``: enables ``updatex0``.
        * ``'keepx0'``: enables ``restorex0``.

        ``restore_step`` is stored but not read by this class.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True
| |
|