import os
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable

import util.util as util

from .base_model import BaseModel
from . import networks
|
|
class UIModel(BaseModel):
    """Inference-only generator wrapper with in-place map editing and one-level undo.

    The model holds a one-hot label input (plus one edge channel when
    instance maps are enabled), a label map, an optional instance map and a
    per-pixel feature map.  The editing methods modify these maps in place
    and re-render ``self.fake_image`` through the generator.

    Object ids below 1000 are treated as plain label ids; ids of 1000 and
    above are treated as ``label * 1000 + k``.  The ``*_original``
    attributes hold the state right after ``load_image``; the ``*_prev``
    attributes hold the snapshot that ``undo`` restores.
    """

    def name(self):
        """Return the identifier string of this model."""
        return 'UIModel'

    def initialize(self, opt):
        """Build the generator and load its weights.

        Args:
            opt: parsed options object; must have ``isTrain`` false.

        Raises:
            AssertionError: if ``opt.isTrain`` is set.
        """
        assert not opt.isTrain, 'UIModel only supports inference (opt.isTrain must be False)'
        BaseModel.initialize(self, opt)
        self.use_features = opt.instance_feat or opt.label_feat

        # Generator input = one-hot labels [+ edge channel] [+ feature channels].
        netG_input_nc = opt.label_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num

        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global,
                                      opt.n_local_enhancers, opt.n_blocks_local,
                                      opt.norm, gpu_ids=self.gpu_ids)
        self.load_network(self.netG, 'G', opt.which_epoch)

        print('---------- Networks initialized -------------')

    def _to_device(self, tensor):
        """Move ``tensor`` to the GPU when GPU ids are configured, else return it as is."""
        # NOTE(review): assumes self.Tensor (from BaseModel) follows the same
        # gpu_ids convention -- confirm against base_model.
        return tensor.cuda() if self.gpu_ids else tensor

    @staticmethod
    def _as_inference_input(tensor):
        """Mark ``tensor`` as an inference-only network input.

        ``volatile`` has no effect on PyTorch >= 0.4 (gradient tracking is
        disabled with ``torch.no_grad`` in ``single_forward`` instead), so
        the legacy wrapper is only applied on older versions.
        """
        if hasattr(torch, 'no_grad'):
            return tensor
        return Variable(tensor, volatile=True)

    def toTensor(self, img, normalize=False):
        """Convert a PIL image to a float tensor of shape (1, C, H, W).

        Args:
            img: PIL image.
            normalize: when true, map values from [0, 255] to [-1, 1].

        Returns:
            Float tensor holding the raw pixel values (or normalised values).
        """
        # np.array always copies, which works on every NumPy version
        # (``copy=False`` raises on NumPy 2 when a copy is unavoidable).
        pixels = np.array(img, dtype=np.int32)
        # getbands() gives the real channel count even for modes such as
        # 'I;16' whose name is longer than their number of bands.
        n_channels = len(img.getbands())
        tensor = torch.from_numpy(pixels).view(1, img.size[1], img.size[0], n_channels)
        tensor = tensor.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
        if normalize:
            return (tensor.float() / 255.0 - 0.5) / 0.5
        return tensor.float()

    def load_image(self, label_path, inst_path, feat_path):
        """Load label / instance maps and clustered features, and build the network input.

        Args:
            label_path: path of the label image.  If the path contains
                ``'face'`` the image is converted to single-channel first.
            inst_path: path of the instance image (read unless
                ``opt.no_instance`` is set).
            feat_path: path of a ``.npy`` file holding a pickled dict that
                maps label id to an array of cluster centres
                (rows = clusters, columns = feature channels).
        """
        opt = self.opt
        has_instance = not opt.no_instance

        # Label map, resized to width opt.loadSize keeping the aspect ratio.
        with Image.open(label_path) as raw_label:
            label_img = raw_label.convert('L') if 'face' in label_path else raw_label
            ow, oh = label_img.size
            w = opt.loadSize
            h = int(w * oh / ow)
            label_img = label_img.resize((w, h), Image.NEAREST)
        self.label_map = self._to_device(self.toTensor(label_img))

        # One-hot encoding of the label map.
        input_label = self.Tensor(torch.Size((1, opt.label_nc, h, w))).zero_()
        self.input_label = input_label.scatter_(1, self.label_map.long(), 1.0)

        # Instance map and the boundary channel derived from it.
        if has_instance:
            with Image.open(inst_path) as raw_inst:
                inst_img = raw_inst.resize((w, h), Image.NEAREST)
            self.inst_map = self._to_device(self.toTensor(inst_img))
            self.edge_map = self.get_edges(self.inst_map)
            net_input = torch.cat((self.input_label, self.edge_map), dim=1)
        else:
            net_input = self.input_label
        self.net_input = self._as_inference_input(net_input)

        if self.use_features:
            # SECURITY: the file is a pickled dict, so loading it executes
            # pickle -- only load feature files from a trusted source.
            # ``allow_pickle`` is required by current NumPy; ``latin1`` lets
            # Python 3 read files that were written by Python 2.
            self.features_clustered = np.load(
                feat_path, allow_pickle=True, encoding='latin1').item()
        else:
            self.features_clustered = {}
        self.object_map = self.inst_map if opt.instance_feat else self.label_map

        # Give every object a deterministic pseudo-random cluster.
        object_np = self.object_map.cpu().numpy().astype(int)
        self.feat_map = self.Tensor(1, opt.feat_num, h, w).zero_()
        self.cluster_indices = np.zeros(opt.label_nc, np.uint8)
        for obj_id in np.unique(object_np):
            obj_id = int(obj_id)
            label = obj_id if obj_id < 1000 else obj_id // 1000
            if label in self.features_clustered:
                feat = self.features_clustered[label]
                # A private RandomState draws the same value as seeding the
                # global generator would, without disturbing global state.
                cluster_idx = np.random.RandomState(obj_id + 1).randint(0, feat.shape[0])
                self.cluster_indices[label] = cluster_idx
                idx = (self.object_map == obj_id).nonzero()
                self.set_features(idx, feat, cluster_idx)

        # Pristine copies used by reset().
        self.net_input_original = self.net_input.clone()
        self.label_map_original = self.label_map.clone()
        self.feat_map_original = self.feat_map.clone()
        if has_instance:
            self.inst_map_original = self.inst_map.clone()

    def reset(self):
        """Restore the state captured by ``load_image`` and reset the undo snapshot."""
        # Current and previous state get separate clones so that in-place
        # edits of the current maps never leak into the undo snapshot.
        self.net_input = self.net_input_original.clone()
        self.net_input_prev = self.net_input_original.clone()
        self.label_map = self.label_map_original.clone()
        self.label_map_prev = self.label_map_original.clone()
        self.feat_map = self.feat_map_original.clone()
        self.feat_map_prev = self.feat_map_original.clone()
        if not self.opt.no_instance:
            self.inst_map = self.inst_map_original.clone()
            self.inst_map_prev = self.inst_map_original.clone()
        self.object_map = self.inst_map if self.opt.instance_feat else self.label_map

    def undo(self):
        """Restore the last snapshot; does nothing when no snapshot exists yet."""
        if not hasattr(self, 'net_input_prev'):
            return
        # Clone so that later in-place edits do not modify the snapshot.
        self.net_input = self.net_input_prev.clone()
        self.label_map = self.label_map_prev.clone()
        self.feat_map = self.feat_map_prev.clone()
        if not self.opt.no_instance:
            self.inst_map = self.inst_map_prev.clone()
        self.object_map = self.inst_map if self.opt.instance_feat else self.label_map

    def get_edges(self, t):
        """Return a float 0/1 map marking pixels that differ from a 4-neighbour.

        Args:
            t: 4-D tensor (N, C, H, W), e.g. an instance map.

        Returns:
            Float tensor of the same shape and on the same device as ``t``.
        """
        # Accumulate in float so the code does not depend on whether
        # comparisons yield byte or bool tensors in the installed PyTorch.
        edge = torch.zeros_like(t).float()
        diff_x = (t[:, :, :, 1:] != t[:, :, :, :-1]).float()
        diff_y = (t[:, :, 1:, :] != t[:, :, :-1, :]).float()
        edge[:, :, :, 1:] += diff_x
        edge[:, :, :, :-1] += diff_x
        edge[:, :, 1:, :] += diff_y
        edge[:, :, :-1, :] += diff_y
        return (edge > 0).float()

    def _paint_region(self, idx, label, inst_id):
        """Write ``label`` / ``inst_id`` at the pixels in ``idx`` and refresh the network input.

        ``idx`` is an (K, 4) index tensor of (n, c, y, x) rows with c == 0.
        """
        n, c, y, x = idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]
        self.label_map[n, c, y, x] = label
        # Clear every one-hot channel first so no stale label survives.
        for k in range(self.opt.label_nc):
            self.net_input[n, c + k, y, x] = 0
        self.net_input[n, c + label, y, x] = 1
        if not self.opt.no_instance:
            self.inst_map[n, c, y, x] = inst_id
            # The edge channel is the last channel of the network input.
            self.net_input[:, -1, :, :] = self.get_edges(self.inst_map)

    def _mask_to_indices(self, mask, click_pt):
        """Turn an (H, W, C) mask placed at ``click_pt`` into in-bounds (n, c, y, x) indices.

        Returns ``None`` when no mask pixel lands inside the current image.
        """
        y, x = click_pt[0], click_pt[1]
        mask = np.transpose(mask, (2, 0, 1))[np.newaxis, ...]
        idx = self._to_device(torch.from_numpy(mask)).nonzero()
        if idx.numel() == 0:
            return None
        idx[:, 2] += y
        idx[:, 3] += x
        # Drop pixels shifted outside the image; out-of-range indices would
        # otherwise fail in the indexed assignments.
        size = self.net_input.size()
        h, w = size[2], size[3]
        inside = (idx[:, 2] >= 0) & (idx[:, 2] < h) & (idx[:, 3] >= 0) & (idx[:, 3] < w)
        keep = inside.nonzero()
        if keep.numel() == 0:
            return None
        return idx.index_select(0, keep.view(-1))

    def change_labels(self, click_src, click_tgt):
        """Relabel the instance under ``click_src`` with the label under ``click_tgt``.

        Updates the label map, instance map, network input and feature map,
        then re-renders ``self.fake_image``.  Requires instance maps
        (``self.inst_map``).

        Args:
            click_src: (y, x) of a pixel inside the instance to change.
            click_tgt: (y, x) of a pixel carrying the desired label.
        """
        y_src, x_src = click_src[0], click_src[1]
        y_tgt, x_tgt = click_tgt[0], click_tgt[1]
        label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt])
        inst_src = float(self.inst_map[0, 0, y_src, x_src])
        inst_tgt = float(self.inst_map[0, 0, y_tgt, x_tgt])

        idx_src = (self.inst_map == inst_src).nonzero()
        if idx_src.numel() > 0:
            self.backup_current_state()

            if inst_tgt >= 1000:
                # Target ids are label * 1000 + k: give the relabelled object
                # a fresh id in that class.  The range is inclusive at the
                # lower end so an instance with k == 0 is found as well.
                same_class = ((self.inst_map >= label_tgt * 1000) &
                              (self.inst_map < (label_tgt + 1) * 1000))
                if int(same_class.sum()) > 0:
                    inst_tgt = float(self.inst_map[same_class].max()) + 1

            self._paint_region(idx_src, label_tgt, inst_tgt)
            # Take the features from the clicked target pixel.  Looking the
            # target up by instance id after the relabel could select one of
            # the source pixels instead.
            self.copy_features(idx_src, (0, 0, y_tgt, x_tgt))

        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))

    def add_strokes(self, click_src, label_tgt, bw, save):
        """Paint a ``bw`` x ``bw`` square of ``label_tgt`` around ``click_src`` and re-render.

        Args:
            click_src: (y, x) centre of the brush.
            label_tgt: label id to paint.
            bw: brush width in pixels; coordinates are clamped to the image.
            save: when true, snapshot the current state for ``undo`` first.
        """
        size = self.net_input.size()
        h, w = size[2], size[3]
        top, left = click_src[0] - bw // 2, click_src[1] - bw // 2
        rows = [int(min(h - 1, max(0, top + i))) for i in range(bw)]
        cols = [int(min(w - 1, max(0, left + j))) for j in range(bw)]
        coords = [[0, 0, r, c] for r in rows for c in cols]

        if coords:
            idx_src = self._to_device(torch.LongTensor(coords))
            if save:
                self.backup_current_state()

            # Painted pixels use the label id as their instance id.
            self._paint_region(idx_src, label_tgt, label_tgt)

            if self.opt.instance_feat and label_tgt in self.features_clustered:
                feat = self.features_clustered[label_tgt]
                cluster_idx = self.cluster_indices[label_tgt]
                self.set_features(idx_src, feat, cluster_idx)

        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))

    def add_objects(self, click_src, label_tgt, mask, style_id=0):
        """Stamp an object mask of label ``label_tgt`` at ``click_src`` and re-render.

        Features are taken from ``self.feat``, which is set by
        ``style_forward`` when called with a click point.

        Args:
            click_src: (y, x) offset added to the mask coordinates.
            label_tgt: label id of the new object.
            mask: (H, W, C) array whose non-zero entries select pixels.
            style_id: row of ``self.feat`` to use as the object's features.
        """
        idx_src = self._mask_to_indices(mask, click_src)
        if idx_src is not None:
            self.backup_current_state()
            self._paint_region(idx_src, label_tgt, label_tgt)
            self.set_features(idx_src, self.feat, style_id)

        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))

    def single_forward(self, net_input, feat_map):
        """Run the generator once.

        Args:
            net_input: label (and edge) input, shape (N, C, H, W).
            feat_map: feature map, appended along the channel axis when
                features are enabled.

        Returns:
            The generated image data; the batch axis is dropped when N == 1.
        """
        # initialize() only reserves feature channels when use_features is
        # set, so the concatenation has to follow the same condition.
        if self.use_features:
            net_input = torch.cat((net_input, feat_map), dim=1)

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                fake_image = self.netG.forward(net_input)
        else:
            fake_image = self.netG.forward(net_input)

        if fake_image.size()[0] == 1:
            return fake_image.data[0]
        return fake_image.data

    def style_forward(self, click_pt, style_id=-1):
        """Render the image, or preview / apply styles for the clicked object.

        Args:
            click_pt: ``None`` to render the whole image with the current
                maps, otherwise (y, x) of a pixel inside the object to style.
            style_id: ``-1`` to render one cropped preview per cluster into
                the list ``self.fake_image`` without changing the feature
                map; any other value applies that cluster to the object and
                renders the full image.

        Raises:
            KeyError: if the clicked object's label has no clustered features.
        """
        if click_pt is None:
            self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
            self.crop = None
            self.mask = None
            return

        instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]])
        self.instToChange = instToChange
        label = instToChange if instToChange < 1000 else instToChange // 1000
        self.feat = self.features_clustered[label]
        self.fake_image = []
        # The clicked pixel always belongs to the mask, so idx is never empty.
        self.mask = self.object_map == instToChange
        idx = self.mask.nonzero()
        self.get_crop_region(idx)

        if style_id == -1:
            min_y, min_x, max_y, max_x = self.crop
            # Previews must not leave the last previewed style behind.
            saved_feat_map = self.feat_map.clone()
            # Never ask for more styles than there are cluster centres.
            n_styles = min(self.opt.multiple_output, self.feat.shape[0])
            for cluster_idx in range(n_styles):
                self.set_features(idx, self.feat, cluster_idx)
                fake_image = self.single_forward(self.net_input, self.feat_map)
                fake_image = util.tensor2im(fake_image[:, min_y:max_y, min_x:max_x])
                self.fake_image.append(fake_image)
            self.feat_map.copy_(saved_feat_map)
            # NOTE: an earlier revision kept a disabled fast path here that
            # cropped or 2x-downsampled the inputs before the forward pass
            # to speed up previews.
        else:
            self.set_features(idx, self.feat, style_id)
            self.cluster_indices[label] = style_id
            self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))

    def backup_current_state(self):
        """Snapshot all editable maps into the ``*_prev`` attributes for ``undo``."""
        self.net_input_prev = self.net_input.clone()
        self.label_map_prev = self.label_map.clone()
        self.feat_map_prev = self.feat_map.clone()
        # Same guard as reset()/undo(): there is no instance map otherwise.
        if not self.opt.no_instance:
            self.inst_map_prev = self.inst_map.clone()

    def get_crop_region(self, idx):
        """Compute the crop window around the pixels in ``idx`` and crop ``self.mask`` to it.

        Sets ``self.crop = (min_y, min_x, max_y, max_x)`` as Python ints with
        exclusive upper bounds.  Windows smaller than 128 pixels along an
        axis are grown to 128 (limited by the image borders).

        Args:
            idx: (K, 4) index tensor of (n, c, y, x) rows; must be non-empty.
        """
        size = self.net_input.size()
        h, w = size[2], size[3]
        min_y, min_x = int(idx[:, 2].min()), int(idx[:, 3].min())
        # +1 makes the bounds half-open so that slicing keeps the last
        # row / column of the object.
        max_y, max_x = int(idx[:, 2].max()) + 1, int(idx[:, 3].max()) + 1
        crop_min = 128
        if max_y - min_y < crop_min:
            min_y = max(0, (max_y + min_y) // 2 - crop_min // 2)
            max_y = min(h, min_y + crop_min)
        if max_x - min_x < crop_min:
            min_x = max(0, (max_x + min_x) // 2 - crop_min // 2)
            max_x = min(w, min_x + crop_min)
        self.crop = (min_y, min_x, max_y, max_x)
        self.mask = self.mask[:, :, min_y:max_y, min_x:max_x]

    def update_features(self, cluster_idx, mask=None, click_pt=None):
        """Apply cluster ``cluster_idx`` of ``self.feat`` to an object.

        Args:
            cluster_idx: row of ``self.feat`` to write.
            mask: optional (H, W, C) array selecting the pixels; when omitted
                the object chosen by the last ``style_forward`` click
                (``self.instToChange``) is used.
            click_pt: (y, x) offset of ``mask``; required with ``mask``.
        """
        # Snapshot everything, not just the feature map, so undo() restores
        # one consistent state.
        self.backup_current_state()

        if mask is not None:
            idx = self._mask_to_indices(mask, click_pt)
        else:
            idx = (self.object_map == self.instToChange).nonzero()

        if idx is not None and idx.numel() > 0:
            self.set_features(idx, self.feat, cluster_idx)

    def set_features(self, idx, feat, cluster_idx):
        """Write row ``cluster_idx`` of ``feat`` into the feature map at the pixels in ``idx``.

        Args:
            idx: (K, 4) index tensor of (n, c, y, x) rows with c == 0.
            feat: 2-D array indexed as ``feat[cluster_idx, channel]``.
            cluster_idx: row of ``feat`` to use.
        """
        n, c, y, x = idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]
        for k in range(self.opt.feat_num):
            # float() keeps NumPy scalar types out of the tensor assignment.
            self.feat_map[n, c + k, y, x] = float(feat[cluster_idx, k])

    def copy_features(self, idx_src, idx_tgt):
        """Copy the feature vector at pixel ``idx_tgt`` to every pixel in ``idx_src``.

        Args:
            idx_src: (K, 4) index tensor of destination (n, c, y, x) rows.
            idx_tgt: a single (n, c, y, x) index (tensor row or sequence).
        """
        for k in range(self.opt.feat_num):
            val = self.feat_map[idx_tgt[0], idx_tgt[1] + k, idx_tgt[2], idx_tgt[3]]
            self.feat_map[idx_src[:, 0], idx_src[:, 1] + k, idx_src[:, 2], idx_src[:, 3]] = val

    def get_current_visuals(self, getLabel=False):
        """Return the current outputs as an ordered dict.

        Args:
            getLabel: when true, also include a ``'label'`` visualisation of
                the network input.

        Returns:
            OrderedDict with ``'fake_image'``, ``'mask'`` (uint8 H x W x C
            array, or ``None``) and optionally ``'label'``.  Entries that
            have not been produced yet are ``None``.
        """
        mask = getattr(self, 'mask', None)
        if mask is not None:
            mask = np.transpose(mask[0].cpu().float().numpy(), (1, 2, 0)).astype(np.uint8)

        dict_list = [('fake_image', getattr(self, 'fake_image', None)), ('mask', mask)]

        if getLabel:
            label = util.tensor2label(self.net_input.data[0], self.opt.label_nc)
            dict_list.append(('label', label))

        return OrderedDict(dict_list)