提交 991f2a7e 编写于 作者: H haoyuying

revise user-guided colorization second time

上级 c07d1ffe
...@@ -46,89 +46,39 @@ class UserGuidedColorization(nn.Layer): ...@@ -46,89 +46,39 @@ class UserGuidedColorization(nn.Layer):
self.output_nc = 2 self.output_nc = 2
self.classification = classification self.classification = classification
# Conv1 # Conv1
model1 = ( model1 = (Conv2d(self.input_nc, 64, 3, 1, 1), nn.ReLU(), Conv2d(64, 64, 3, 1, 1), nn.ReLU(), nn.BatchNorm(64))
Conv2d(self.input_nc, 64, 3, 1, 1),
nn.ReLU(),
Conv2d(64, 64, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(64),
)
# Conv2 # Conv2
model2 = ( model2 = (Conv2d(64, 128, 3, 1, 1), nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128))
Conv2d(64, 128, 3, 1, 1),
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
# Conv3 # Conv3
model3 = ( model3 = (Conv2d(128, 256, 3, 1, 1), nn.ReLU(), Conv2d(256, 256, 3, 1,
Conv2d(128, 256, 3, 1, 1), 1), nn.ReLU(), Conv2d(256, 256, 3, 1,
nn.ReLU(), 1), nn.ReLU(), nn.BatchNorm(256))
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(256),
)
# Conv4 # Conv4
model4 = ( model4 = (Conv2d(256, 512, 3, 1, 1), nn.ReLU(), Conv2d(512, 512, 3, 1,
Conv2d(256, 512, 3, 1, 1), 1), nn.ReLU(), Conv2d(512, 512, 3, 1,
nn.ReLU(), 1), nn.ReLU(), nn.BatchNorm(512))
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv5 # Conv5
model5 = ( model5 = (Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), nn.BatchNorm(512))
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv6 # Conv6
model6 = ( model6 = (Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), nn.BatchNorm(512))
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv7 # Conv7
model7 = ( model7 = (Conv2d(512, 512, 3, 1, 1), nn.ReLU(), Conv2d(512, 512, 3, 1,
Conv2d(512, 512, 3, 1, 1), 1), nn.ReLU(), Conv2d(512, 512, 3, 1,
nn.ReLU(), 1), nn.ReLU(), nn.BatchNorm(512))
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv8 # Conv8
model8up = (ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), ) model8up = (ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), )
model3short8 = (Conv2d(256, 256, 3, 1, 1), ) model3short8 = (Conv2d(256, 256, 3, 1, 1), )
model8 = ( model8 = (nn.ReLU(), Conv2d(256, 256, 3, 1, 1), nn.ReLU(), Conv2d(256, 256, 3, 1,
nn.ReLU(), 1), nn.ReLU(), nn.BatchNorm(256))
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(256),
)
# Conv9 # Conv9
model9up = (ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), ) model9up = (ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), )
...@@ -138,13 +88,8 @@ class UserGuidedColorization(nn.Layer): ...@@ -138,13 +88,8 @@ class UserGuidedColorization(nn.Layer):
3, 3,
1, 1,
1, 1,
), ) ))
model9 = ( model9 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128))
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
# Conv10 # Conv10
model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), ) model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), )
...@@ -212,6 +157,7 @@ class UserGuidedColorization(nn.Layer): ...@@ -212,6 +157,7 @@ class UserGuidedColorization(nn.Layer):
mask_B: paddle.Tensor, mask_B: paddle.Tensor,
real_b: paddle.Tensor = None, real_b: paddle.Tensor = None,
real_B_enc: paddle.Tensor = None) -> paddle.Tensor: real_B_enc: paddle.Tensor = None) -> paddle.Tensor:
conv1_2 = self.model1(paddle.concat([input_A, input_B, mask_B], axis=1)) conv1_2 = self.model1(paddle.concat([input_A, input_B, mask_B], axis=1))
conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
......
...@@ -128,26 +128,30 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -128,26 +128,30 @@ class ImageColorizeModule(RunModule, ImageServing):
''' '''
out_class, out_reg = self(batch[0], batch[1], batch[2]) out_class, out_reg = self(batch[0], batch[1], batch[2])
# loss
criterionCE = nn.loss.CrossEntropyLoss() criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, batch[4][:, 0, :, :]) loss_ce = criterionCE(out_class, batch[4][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg loss = loss_ce + loss_G_L1_reg
# psnr
visual_ret = OrderedDict()
psnrs = [] psnrs = []
lab2rgb = ConvertColorSpace(mode='LAB2RGB') visual_ret = OrderedDict()
process = ColorPostprocess() process = ColorPostprocess()
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
for i in range(batch[0].numpy().shape[0]): for i in range(batch[0].numpy().shape[0]):
real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i] real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
visual_ret['real'] = process(real) visual_ret['real'] = process(real)
fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i] fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = process(fake) visual_ret['fake_reg'] = process(fake)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2) mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value) psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs))
psnr = paddle.to_tensor(np.array(psnrs))
return {'loss': loss, 'metrics': {'psnr': psnr}} return {'loss': loss, 'metrics': {'psnr': psnr}}
...@@ -157,28 +161,32 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -157,28 +161,32 @@ class ImageColorizeModule(RunModule, ImageServing):
Args: Args:
images(str) : Images path to be colorized. images(str) : Images path to be colorized.
visualization(bool): Whether to save colorized images. visualization(bool): Whether to save colorized images, default is True.
save_path(str) : Path to save colorized images. save_path(str) : Path to save colorized images, default is result/.
Returns: Returns:
results(list[dict]) : The prediction result of each input image results(list[dict]) : The prediction result of each input image
''' '''
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
process = ColorPostprocess()
resize = Resize((256, 256))
visual_ret = OrderedDict()
im = self.transforms(images, is_train=False) im = self.transforms(images, is_train=False)
out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']), out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']),
paddle.to_variable(im['mask_B'])) paddle.to_variable(im['mask_B']))
result = [] result = []
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
process = ColorPostprocess()
resize = Resize((256, 256))
visual_ret = OrderedDict()
for i in range(im['A'].shape[0]): for i in range(im['A'].shape[0]):
gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i] gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i]
visual_ret['gray'] = resize(process(gray)) visual_ret['gray'] = resize(process(gray))
hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i] hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i]
visual_ret['hint'] = resize(process(hint)) visual_ret['hint'] = resize(process(hint))
real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i] real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i]
visual_ret['real'] = resize(process(real)) visual_ret['real'] = resize(process(real))
fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i] fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = resize(process(fake)) visual_ret['fake_reg'] = resize(process(fake))
...@@ -190,8 +198,6 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -190,8 +198,6 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_gray = Image.fromarray(visual_ret['fake_reg']) visual_gray = Image.fromarray(visual_ret['fake_reg'])
visual_gray.save(fake_path) visual_gray.save(fake_path)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
result.append(visual_ret) result.append(visual_ret)
return result return result
...@@ -221,29 +227,31 @@ class StyleTransferModule(RunModule, ImageServing): ...@@ -221,29 +227,31 @@ class StyleTransferModule(RunModule, ImageServing):
Returns: Returns:
results(dict) : The model outputs, such as metrics. results(dict) : The model outputs, such as metrics.
''' '''
mse_loss = nn.MSELoss() mse_loss = nn.MSELoss()
N, C, H, W = batch[0].shape N, C, H, W = batch[0].shape
batch[1] = batch[1][0].unsqueeze(0) batch[1] = batch[1][0].unsqueeze(0)
self.setTarget(batch[1]) self.setTarget(batch[1])
y = self(batch[0]) y = self(batch[0])
xc = paddle.to_tensor(batch[0].numpy().copy())
y = subtract_imagenet_mean_batch(y) y = subtract_imagenet_mean_batch(y)
xc = subtract_imagenet_mean_batch(xc)
features_y = self.getFeature(y) features_y = self.getFeature(y)
xc = paddle.to_tensor(batch[0].numpy().copy())
xc = subtract_imagenet_mean_batch(xc)
features_xc = self.getFeature(xc) features_xc = self.getFeature(xc)
f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True) f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
content_loss = mse_loss(features_y[1], f_xc_c)
batch[1] = subtract_imagenet_mean_batch(batch[1]) batch[1] = subtract_imagenet_mean_batch(batch[1])
features_style = self.getFeature(batch[1]) features_style = self.getFeature(batch[1])
gram_style = [gram_matrix(y) for y in features_style] gram_style = [gram_matrix(y) for y in features_style]
style_loss = 0. style_loss = 0.
for m in range(len(features_y)): for m in range(len(features_y)):
gram_y = gram_matrix(features_y[m]) gram_y = gram_matrix(features_y[m])
gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1))) gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
style_loss += mse_loss(gram_y, gram_s[:N, :, :]) style_loss += mse_loss(gram_y, gram_s[:N, :, :])
content_loss = mse_loss(features_y[1], f_xc_c)
loss = content_loss + style_loss loss = content_loss + style_loss
return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}} return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}
...@@ -261,10 +269,8 @@ class StyleTransferModule(RunModule, ImageServing): ...@@ -261,10 +269,8 @@ class StyleTransferModule(RunModule, ImageServing):
Returns: Returns:
output(np.ndarray) : The style transformed images with bgr mode. output(np.ndarray) : The style transformed images with bgr mode.
''' '''
content = paddle.to_tensor(self.transform(origin_path)) content = paddle.to_tensor(self.transform(origin_path)).unsqueeze(0)
style = paddle.to_tensor(self.transform(style_path)) style = paddle.to_tensor(self.transform(style_path)).unsqueeze(0)
content = content.unsqueeze(0)
style = style.unsqueeze(0)
self.setTarget(style) self.setTarget(style)
output = self(content) output = self(content)
......
...@@ -241,12 +241,7 @@ class RandomPaddingCrop: ...@@ -241,12 +241,7 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder(im, im = cv2.copyMakeBorder(im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, \
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value) value=self.im_padding_value)
if crop_height > 0 and crop_width > 0: if crop_height > 0 and crop_width > 0:
...@@ -302,11 +297,7 @@ class RandomRotation: ...@@ -302,11 +297,7 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy r[1, 2] += (nh / 2) - cy
dsize = (nw, nh) dsize = (nw, nh)
im = cv2.warpAffine(im, im = cv2.warpAffine(im, r, dsize=dsize, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, \
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value) borderValue=self.im_padding_value)
return im return im
...@@ -371,7 +362,7 @@ class RandomDistort: ...@@ -371,7 +362,7 @@ class RandomDistort:
saturation_upper = 1 + self.saturation_range saturation_upper = 1 + self.saturation_range
hue_lower = -self.hue_range hue_lower = -self.hue_range
hue_upper = self.hue_range hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue] ops = ["brightness", "contrast", "saturation", "hue"]
random.shuffle(ops) random.shuffle(ops)
params_dict = { params_dict = {
'brightness': { 'brightness': {
...@@ -553,7 +544,8 @@ class ConvertColorSpace: ...@@ -553,7 +544,8 @@ class ConvertColorSpace:
class ColorizeHint: class ColorizeHint:
"""Get hint and mask images for colorization. """Get hint and mask images for colorization.
This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process. This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain
the local hints and correspoding mask to guid colorization process.
Args: Args:
percent(float): Probability for ignoring hint in an iteration. percent(float): Probability for ignoring hint in an iteration.
...@@ -577,6 +569,7 @@ class ColorizeHint: ...@@ -577,6 +569,7 @@ class ColorizeHint:
self.hint = hint self.hint = hint
self.mask = mask self.mask = mask
N, C, H, W = data.shape N, C, H, W = data.shape
for nn in range(N): for nn in range(N):
pp = 0 pp = 0
cont_cond = True cont_cond = True
...@@ -599,11 +592,9 @@ class ColorizeHint: ...@@ -599,11 +592,9 @@ class ColorizeHint:
# add color point # add color point
if self.use_avg: if self.use_avg:
# embed() # embed()
hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P], hint[nn, :, h:h + P, w:w + P] = \
axis=2, np.mean(np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True),
keepdims=True), axis=1, keepdims=True).reshape(1, C, 1, 1)
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else: else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1 mask[nn, :, h:h + P, w:w + P] = 1
...@@ -667,7 +658,8 @@ class ColorizePreprocess: ...@@ -667,7 +658,8 @@ class ColorizePreprocess:
def __call__(self, data_lab: np.ndarray): def __call__(self, data_lab: np.ndarray):
""" """
This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for colorization task. This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for
colorization task.
Args: Args:
img(np.ndarray): LAB image. img(np.ndarray): LAB image.
...@@ -681,20 +673,23 @@ class ColorizePreprocess: ...@@ -681,20 +673,23 @@ class ColorizePreprocess:
0, 0,
], :, :] ], :, :]
data['B'] = data_lab[:, 1:, :, :] data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110 thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)), mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - \
axis=1) np.min(np.min(data['B'], axis=3), axis=2)), axis=1)
mask = (mask >= thresh) mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :] data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :] data['B'] = data['B'][mask, :, :, :]
if np.sum(mask) == 0: if np.sum(mask) == 0:
return None return None
data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number
data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
data['hint_B'] = np.zeros(shape=data['B'].shape) data['hint_B'] = np.zeros(shape=data['B'].shape)
data['mask_B'] = np.zeros(shape=data['A'].shape) data['mask_B'] = np.zeros(shape=data['A'].shape)
data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B']) data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
if self.is_train: if self.is_train:
data = self.squeeze(data) data = self.squeeze(data)
data['real_B_enc'] = data['real_B_enc'].astype(np.int64) data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册