提交 991f2a7e 编写于 作者: H haoyuying

revise user-guided colorization second time

上级 c07d1ffe
......@@ -46,89 +46,39 @@ class UserGuidedColorization(nn.Layer):
self.output_nc = 2
self.classification = classification
# Conv1
model1 = (
Conv2d(self.input_nc, 64, 3, 1, 1),
nn.ReLU(),
Conv2d(64, 64, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(64),
)
model1 = (Conv2d(self.input_nc, 64, 3, 1, 1), nn.ReLU(), Conv2d(64, 64, 3, 1, 1), nn.ReLU(), nn.BatchNorm(64))
# Conv2
model2 = (
Conv2d(64, 128, 3, 1, 1),
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
model2 = (Conv2d(64, 128, 3, 1, 1), nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128))
# Conv3
model3 = (
Conv2d(128, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(256),
)
model3 = (Conv2d(128, 256, 3, 1, 1), nn.ReLU(), Conv2d(256, 256, 3, 1,
1), nn.ReLU(), Conv2d(256, 256, 3, 1,
1), nn.ReLU(), nn.BatchNorm(256))
# Conv4
model4 = (
Conv2d(256, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(512),
)
model4 = (Conv2d(256, 512, 3, 1, 1), nn.ReLU(), Conv2d(512, 512, 3, 1,
1), nn.ReLU(), Conv2d(512, 512, 3, 1,
1), nn.ReLU(), nn.BatchNorm(512))
# Conv5
model5 = (
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
nn.BatchNorm(512),
)
model5 = (Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), nn.BatchNorm(512))
# Conv6
model6 = (
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
nn.BatchNorm(512),
)
model6 = (Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2), nn.ReLU(), nn.BatchNorm(512))
# Conv7
model7 = (
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(512),
)
model7 = (Conv2d(512, 512, 3, 1, 1), nn.ReLU(), Conv2d(512, 512, 3, 1,
1), nn.ReLU(), Conv2d(512, 512, 3, 1,
1), nn.ReLU(), nn.BatchNorm(512))
# Conv8
model8up = (ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), )
model3short8 = (Conv2d(256, 256, 3, 1, 1), )
model8 = (
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(256),
)
model8 = (nn.ReLU(), Conv2d(256, 256, 3, 1, 1), nn.ReLU(), Conv2d(256, 256, 3, 1,
1), nn.ReLU(), nn.BatchNorm(256))
# Conv9
model9up = (ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), )
......@@ -138,13 +88,8 @@ class UserGuidedColorization(nn.Layer):
3,
1,
1,
), )
model9 = (
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
))
model9 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.ReLU(), nn.BatchNorm(128))
# Conv10
model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), )
......@@ -212,6 +157,7 @@ class UserGuidedColorization(nn.Layer):
mask_B: paddle.Tensor,
real_b: paddle.Tensor = None,
real_B_enc: paddle.Tensor = None) -> paddle.Tensor:
conv1_2 = self.model1(paddle.concat([input_A, input_B, mask_B], axis=1))
conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
......
......@@ -128,26 +128,30 @@ class ImageColorizeModule(RunModule, ImageServing):
'''
out_class, out_reg = self(batch[0], batch[1], batch[2])
# loss
criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, batch[4][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg
visual_ret = OrderedDict()
# psnr
psnrs = []
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
visual_ret = OrderedDict()
process = ColorPostprocess()
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
for i in range(batch[0].numpy().shape[0]):
real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
visual_ret['real'] = process(real)
fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = process(fake)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs))
psnr = paddle.to_tensor(np.array(psnrs))
return {'loss': loss, 'metrics': {'psnr': psnr}}
......@@ -157,28 +161,32 @@ class ImageColorizeModule(RunModule, ImageServing):
Args:
images(str) : Images path to be colorized.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
visualization(bool): Whether to save colorized images, default is True.
save_path(str) : Path to save colorized images, default is result/.
Returns:
results(list[dict]) : The prediction result of each input image
'''
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
process = ColorPostprocess()
resize = Resize((256, 256))
visual_ret = OrderedDict()
im = self.transforms(images, is_train=False)
out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']),
paddle.to_variable(im['mask_B']))
result = []
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
process = ColorPostprocess()
resize = Resize((256, 256))
visual_ret = OrderedDict()
for i in range(im['A'].shape[0]):
gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i]
visual_ret['gray'] = resize(process(gray))
hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i]
visual_ret['hint'] = resize(process(hint))
real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i]
visual_ret['real'] = resize(process(real))
fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = resize(process(fake))
......@@ -190,8 +198,6 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_gray = Image.fromarray(visual_ret['fake_reg'])
visual_gray.save(fake_path)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
result.append(visual_ret)
return result
......@@ -221,29 +227,31 @@ class StyleTransferModule(RunModule, ImageServing):
Returns:
results(dict) : The model outputs, such as metrics.
'''
mse_loss = nn.MSELoss()
N, C, H, W = batch[0].shape
batch[1] = batch[1][0].unsqueeze(0)
self.setTarget(batch[1])
y = self(batch[0])
xc = paddle.to_tensor(batch[0].numpy().copy())
y = subtract_imagenet_mean_batch(y)
xc = subtract_imagenet_mean_batch(xc)
features_y = self.getFeature(y)
xc = paddle.to_tensor(batch[0].numpy().copy())
xc = subtract_imagenet_mean_batch(xc)
features_xc = self.getFeature(xc)
f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
content_loss = mse_loss(features_y[1], f_xc_c)
batch[1] = subtract_imagenet_mean_batch(batch[1])
features_style = self.getFeature(batch[1])
gram_style = [gram_matrix(y) for y in features_style]
style_loss = 0.
for m in range(len(features_y)):
gram_y = gram_matrix(features_y[m])
gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
style_loss += mse_loss(gram_y, gram_s[:N, :, :])
content_loss = mse_loss(features_y[1], f_xc_c)
loss = content_loss + style_loss
return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}
......@@ -261,10 +269,8 @@ class StyleTransferModule(RunModule, ImageServing):
Returns:
output(np.ndarray) : The style transformed images with bgr mode.
'''
content = paddle.to_tensor(self.transform(origin_path))
style = paddle.to_tensor(self.transform(style_path))
content = content.unsqueeze(0)
style = style.unsqueeze(0)
content = paddle.to_tensor(self.transform(origin_path)).unsqueeze(0)
style = paddle.to_tensor(self.transform(style_path)).unsqueeze(0)
self.setTarget(style)
output = self(content)
......
......@@ -241,12 +241,7 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder(im,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
im = cv2.copyMakeBorder(im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, \
value=self.im_padding_value)
if crop_height > 0 and crop_width > 0:
......@@ -302,11 +297,7 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy
dsize = (nw, nh)
im = cv2.warpAffine(im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
im = cv2.warpAffine(im, r, dsize=dsize, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, \
borderValue=self.im_padding_value)
return im
......@@ -371,7 +362,7 @@ class RandomDistort:
saturation_upper = 1 + self.saturation_range
hue_lower = -self.hue_range
hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue]
ops = ["brightness", "contrast", "saturation", "hue"]
random.shuffle(ops)
params_dict = {
'brightness': {
......@@ -553,7 +544,8 @@ class ConvertColorSpace:
class ColorizeHint:
"""Get hint and mask images for colorization.
This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process.
This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain
the local hints and correspoding mask to guid colorization process.
Args:
percent(float): Probability for ignoring hint in an iteration.
......@@ -577,6 +569,7 @@ class ColorizeHint:
self.hint = hint
self.mask = mask
N, C, H, W = data.shape
for nn in range(N):
pp = 0
cont_cond = True
......@@ -599,11 +592,9 @@ class ColorizeHint:
# add color point
if self.use_avg:
# embed()
hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
axis=2,
keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
hint[nn, :, h:h + P, w:w + P] = \
np.mean(np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True),
axis=1, keepdims=True).reshape(1, C, 1, 1)
else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1
......@@ -667,7 +658,8 @@ class ColorizePreprocess:
def __call__(self, data_lab: np.ndarray):
"""
This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for colorization task.
This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for
colorization task.
Args:
img(np.ndarray): LAB image.
......@@ -681,20 +673,23 @@ class ColorizePreprocess:
0,
], :, :]
data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - \
np.min(np.min(data['B'], axis=3), axis=2)), axis=1)
mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :]
if np.sum(mask) == 0:
return None
data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number
data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
data['hint_B'] = np.zeros(shape=data['B'].shape)
data['mask_B'] = np.zeros(shape=data['A'].shape)
data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
if self.is_train:
data = self.squeeze(data)
data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册