From 991f2a7ef30ddf4ab0e8181e2d94ff12ca75c547 Mon Sep 17 00:00:00 2001 From: haoyuying <18844182690@163.com> Date: Sun, 27 Sep 2020 17:11:27 +0800 Subject: [PATCH] revise user-guided colorization second time --- .../user_guided_colorization/module.py | 94 ++++--------------- paddlehub/module/cv_module.py | 48 +++++----- paddlehub/process/transforms.py | 37 ++++---- 3 files changed, 63 insertions(+), 116 deletions(-) diff --git a/hub_module/modules/image/colorization/user_guided_colorization/module.py b/hub_module/modules/image/colorization/user_guided_colorization/module.py index ebba9933..42f2b855 100644 --- a/hub_module/modules/image/colorization/user_guided_colorization/module.py +++ b/hub_module/modules/image/colorization/user_guided_colorization/module.py @@ -46,89 +46,39 @@ class UserGuidedColorization(nn.Layer): self.output_nc = 2 self.classification = classification # Conv1 - model1 = ( - Conv2d(self.input_nc, 64, 3, 1, 1), - nn.ReLU(), - Conv2d(64, 64, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(64), - ) + model1 = (Conv2d(self.input_nc, 64, 3, 1, 1), nn.ReLU(), Conv2d(64, 64, 3, 1, 1), nn.ReLU(), nn.BatchNorm(64)) # Conv2 - model2 = ( - Conv2d(64, 128, 3, 1, 1), - nn.ReLU(), - Conv2d(128, 128, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(128), - ) + model2 = (Conv2d(64, 128, 3, 1, 1), nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128)) # Conv3 - model3 = ( - Conv2d(128, 256, 3, 1, 1), - nn.ReLU(), - Conv2d(256, 256, 3, 1, 1), - nn.ReLU(), - Conv2d(256, 256, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(256), - ) + model3 = (Conv2d(128, 256, 3, 1, 1), nn.ReLU(), Conv2d(256, 256, 3, 1, + 1), nn.ReLU(), Conv2d(256, 256, 3, 1, + 1), nn.ReLU(), nn.BatchNorm(256)) # Conv4 - model4 = ( - Conv2d(256, 512, 3, 1, 1), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 1), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(512), - ) + model4 = (Conv2d(256, 512, 3, 1, 1), nn.ReLU(), Conv2d(512, 512, 3, 1, + 1), nn.ReLU(), Conv2d(512, 512, 3, 1, + 1), nn.ReLU(), nn.BatchNorm(512)) # Conv5 - model5 = ( - Conv2d(512, 512, 3, 1, 2, 2), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 2, 2), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 2, 2), - nn.ReLU(), - nn.BatchNorm(512), - ) + model5 = (Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), + Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), nn.BatchNorm(512)) # Conv6 - model6 = ( - Conv2d(512, 512, 3, 1, 2, 2), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 2, 2), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 2, 2), - nn.ReLU(), - nn.BatchNorm(512), - ) + model6 = (Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), + Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), nn.BatchNorm(512)) # Conv7 - model7 = ( - Conv2d(512, 512, 3, 1, 1), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 1), - nn.ReLU(), - Conv2d(512, 512, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(512), - ) + model7 = (Conv2d(512, 512, 3, 1, 1), nn.ReLU(), Conv2d(512, 512, 3, 1, + 1), nn.ReLU(), Conv2d(512, 512, 3, 1, + 1), nn.ReLU(), nn.BatchNorm(512)) # Conv8 model8up = (ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), ) model3short8 = (Conv2d(256, 256, 3, 1, 1), ) - model8 = ( - nn.ReLU(), - Conv2d(256, 256, 3, 1, 1), - nn.ReLU(), - Conv2d(256, 256, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(256), - ) + model8 = (nn.ReLU(), Conv2d(256, 256, 3, 1, 1), nn.ReLU(), Conv2d(256, 256, 3, 1, + 1), nn.ReLU(), nn.BatchNorm(256)) # Conv9 model9up = (ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), ) @@ -138,13 +88,8 @@ class UserGuidedColorization(nn.Layer): 3, 1, 1, - ), ) - model9 = ( - nn.ReLU(), - Conv2d(128, 128, 3, 1, 1), - nn.ReLU(), - nn.BatchNorm(128), - ) + )) + model9 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128)) # Conv10 model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), ) @@ -212,6 +157,7 @@ class UserGuidedColorization(nn.Layer): mask_B: paddle.Tensor, real_b: paddle.Tensor = None, real_B_enc: paddle.Tensor = None) -> paddle.Tensor: + conv1_2 = self.model1(paddle.concat([input_A, input_B, mask_B], axis=1)) conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index 80936bd6..29cc93f3 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -128,26 +128,30 @@ class ImageColorizeModule(RunModule, ImageServing): ''' out_class, out_reg = self(batch[0], batch[1], batch[2]) + # loss criterionCE = nn.loss.CrossEntropyLoss() loss_ce = criterionCE(out_class, batch[4][:, 0, :, :]) loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss = loss_ce + loss_G_L1_reg - - visual_ret = OrderedDict() + # psnr psnrs = [] - lab2rgb = ConvertColorSpace(mode='LAB2RGB') + visual_ret = OrderedDict() process = ColorPostprocess() + lab2rgb = ConvertColorSpace(mode='LAB2RGB') for i in range(batch[0].numpy().shape[0]): real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i] visual_ret['real'] = process(real) + fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i] visual_ret['fake_reg'] = process(fake) + mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2) psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnrs.append(psnr_value) - psnr = paddle.to_variable(np.array(psnrs)) + + psnr = paddle.to_tensor(np.array(psnrs)) return {'loss': loss, 'metrics': {'psnr': psnr}} @@ -157,28 +161,32 @@ class ImageColorizeModule(RunModule, ImageServing): Args: images(str) : Images path to be colorized. - visualization(bool): Whether to save colorized images. - save_path(str) : Path to save colorized images. + visualization(bool): Whether to save colorized images, default is True. + save_path(str) : Path to save colorized images, default is result/. Returns: results(list[dict]) : The prediction result of each input image ''' - lab2rgb = ConvertColorSpace(mode='LAB2RGB') - process = ColorPostprocess() - resize = Resize((256, 256)) - visual_ret = OrderedDict() + im = self.transforms(images, is_train=False) out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']), paddle.to_variable(im['mask_B'])) result = [] + lab2rgb = ConvertColorSpace(mode='LAB2RGB') + process = ColorPostprocess() + resize = Resize((256, 256)) + visual_ret = OrderedDict() for i in range(im['A'].shape[0]): gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i] visual_ret['gray'] = resize(process(gray)) + hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i] visual_ret['hint'] = resize(process(hint)) + real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i] visual_ret['real'] = resize(process(real)) + fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i] visual_ret['fake_reg'] = resize(process(fake)) @@ -190,8 +198,6 @@ class ImageColorizeModule(RunModule, ImageServing): visual_gray = Image.fromarray(visual_ret['fake_reg']) visual_gray.save(fake_path) - mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2) - psnr_value = 20 * np.log10(255. / np.sqrt(mse)) result.append(visual_ret) return result @@ -221,29 +227,31 @@ class StyleTransferModule(RunModule, ImageServing): Returns: results(dict) : The model outputs, such as metrics. ''' + mse_loss = nn.MSELoss() N, C, H, W = batch[0].shape + batch[1] = batch[1][0].unsqueeze(0) self.setTarget(batch[1]) - y = self(batch[0]) - xc = paddle.to_tensor(batch[0].numpy().copy()) y = subtract_imagenet_mean_batch(y) - xc = subtract_imagenet_mean_batch(xc) features_y = self.getFeature(y) + + xc = paddle.to_tensor(batch[0].numpy().copy()) + xc = subtract_imagenet_mean_batch(xc) features_xc = self.getFeature(xc) f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True) - content_loss = mse_loss(features_y[1], f_xc_c) batch[1] = subtract_imagenet_mean_batch(batch[1]) features_style = self.getFeature(batch[1]) gram_style = [gram_matrix(y) for y in features_style] + style_loss = 0. for m in range(len(features_y)): gram_y = gram_matrix(features_y[m]) gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1))) style_loss += mse_loss(gram_y, gram_s[:N, :, :]) - + content_loss = mse_loss(features_y[1], f_xc_c) loss = content_loss + style_loss return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}} @@ -261,10 +269,8 @@ class StyleTransferModule(RunModule, ImageServing): Returns: output(np.ndarray) : The style transformed images with bgr mode. ''' - content = paddle.to_tensor(self.transform(origin_path)) - style = paddle.to_tensor(self.transform(style_path)) - content = content.unsqueeze(0) - style = style.unsqueeze(0) + content = paddle.to_tensor(self.transform(origin_path)).unsqueeze(0) + style = paddle.to_tensor(self.transform(style_path)).unsqueeze(0) self.setTarget(style) output = self(content) diff --git a/paddlehub/process/transforms.py b/paddlehub/process/transforms.py index f3fe4d8b..536e3c22 100644 --- a/paddlehub/process/transforms.py +++ b/paddlehub/process/transforms.py @@ -241,12 +241,7 @@ class RandomPaddingCrop: pad_height = max(crop_height - img_height, 0) pad_width = max(crop_width - img_width, 0) if (pad_height > 0 or pad_width > 0): - im = cv2.copyMakeBorder(im, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, + im = cv2.copyMakeBorder(im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, \ value=self.im_padding_value) if crop_height > 0 and crop_width > 0: @@ -302,11 +297,7 @@ class RandomRotation: r[0, 2] += (nw / 2) - cx r[1, 2] += (nh / 2) - cy dsize = (nw, nh) - im = cv2.warpAffine(im, - r, - dsize=dsize, - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, + im = cv2.warpAffine(im, r, dsize=dsize, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, \ borderValue=self.im_padding_value) return im @@ -371,7 +362,7 @@ class RandomDistort: saturation_upper = 1 + self.saturation_range hue_lower = -self.hue_range hue_upper = self.hue_range - ops = [brightness, contrast, saturation, hue] + ops = ["brightness", "contrast", "saturation", "hue"] random.shuffle(ops) params_dict = { 'brightness': { @@ -553,7 +544,8 @@ class ConvertColorSpace: class ColorizeHint: """Get hint and mask images for colorization. - This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process. + This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain + the local hints and correspoding mask to guid colorization process. Args: percent(float): Probability for ignoring hint in an iteration. @@ -577,6 +569,7 @@ class ColorizeHint: self.hint = hint self.mask = mask N, C, H, W = data.shape + for nn in range(N): pp = 0 cont_cond = True @@ -599,11 +592,9 @@ class ColorizeHint: # add color point if self.use_avg: # embed() - hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P], - axis=2, - keepdims=True), - axis=1, - keepdims=True).reshape(1, C, 1, 1) + hint[nn, :, h:h + P, w:w + P] = \ + np.mean(np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), + axis=1, keepdims=True).reshape(1, C, 1, 1) else: hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] mask[nn, :, h:h + P, w:w + P] = 1 @@ -667,7 +658,8 @@ class ColorizePreprocess: def __call__(self, data_lab: np.ndarray): """ - This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for colorization task. + This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for + colorization task. Args: img(np.ndarray): LAB image. @@ -681,20 +673,23 @@ class ColorizePreprocess: 0, ], :, :] data['B'] = data_lab[:, 1:, :, :] + if self.ab_thresh > 0: # mask out grayscale images thresh = 1. * self.ab_thresh / 110 - mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)), - axis=1) + mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - \ + np.min(np.min(data['B'], axis=3), axis=2)), axis=1) mask = (mask >= thresh) data['A'] = data['A'][mask, :, :, :] data['B'] = data['B'][mask, :, :, :] if np.sum(mask) == 0: return None + data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] data['hint_B'] = np.zeros(shape=data['B'].shape) data['mask_B'] = np.zeros(shape=data['A'].shape) data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B']) + if self.is_train: data = self.squeeze(data) data['real_B_enc'] = data['real_B_enc'].astype(np.int64) -- GitLab