提交 8e7ab2a1 编写于 作者: W wuzewu

Merge branch 'pre-develop/v2.0.0' of https://github.com/PaddlePaddle/PaddleHub...

Merge branch 'pre-develop/v2.0.0' of https://github.com/PaddlePaddle/PaddleHub into pre-develop/v2.0.0
import paddle
import paddlehub as hub
import paddle.nn as nn
if __name__ == '__main__':
paddle.disable_static()
model = hub.Module(directory='user_guided_colorization')
model.eval()
result = model.predict(images='sea.jpg')
\ No newline at end of file
import paddle
import paddlehub as hub
import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
if __name__ == '__main__':
is_train = True
paddle.disable_static()
model = hub.Module(directory='user_guided_colorization')
transform = Compose([Resize((256,256),interp="RANDOM"),RandomPaddingCrop(crop_size=176), ConvertColorSpace(mode='RGB2LAB'), ColorizePreprocess(ab_thresh=0, p=1)], stay_rgb=True)
color_set = Colorizedataset(transform=transform, mode=is_train)
if is_train:
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy
import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
from paddlehub.module.cv_module import ImageColorizeModule
@moduleinfo(
name="user_guided_colorization",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="User_guided_colorization is a image colorization model, this module is trained with ILSVRC2012 dataset.",
version="1.0.0",
meta=ImageColorizeModule)
class UserGuidedColorization(nn.Layer):
"""Userguidedcolorization, see https://github.com/haoyuying/colorization-pytorch
Args:
use_tanh (bool): Whether to use tanh as final activation function.
classification (bool): Whether to switch classification branch for optimization.
load_checkpoint (str): Pretrained checkpoint path.
"""
def __init__(self, use_tanh: bool = True, classification: bool = True, load_checkpoint: str = None):
super(UserGuidedColorization, self).__init__()
self.input_nc = 4
self.output_nc = 2
self.classification = classification
# Conv1
model1 = (
Conv2d(self.input_nc, 64, 3, 1, 1),
nn.ReLU(),
Conv2d(64, 64, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(64),
)
# Conv2
model2 = (
Conv2d(64, 128, 3, 1, 1),
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
# Conv3
model3 = (
Conv2d(128, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(256),
)
# Conv4
model4 = (
Conv2d(256, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv5
model5 = (
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv6
model6 = (
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 2, 2),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv7
model7 = (
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
Conv2d(512, 512, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(512),
)
# Conv8
model8up = (ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), )
model3short8 = (Conv2d(256, 256, 3, 1, 1), )
model8 = (
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
Conv2d(256, 256, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(256),
)
# Conv9
model9up = (ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), )
model2short9 = (Conv2d(
128,
128,
3,
1,
1,
), )
model9 = (
nn.ReLU(),
Conv2d(128, 128, 3, 1, 1),
nn.ReLU(),
nn.BatchNorm(128),
)
# Conv10
model10up = (ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1), )
model1short10 = (Conv2d(64, 128, 3, 1, 1), )
model10 = (nn.ReLU(), Conv2d(128, 128, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2))
model_class = (Conv2d(256, 529, 1), )
if use_tanh:
model_out = (Conv2d(128, 2, 1, 1, 0, 1), nn.Tanh())
else:
model_out = (Conv2d(128, 2, 1, 1, 0, 1), )
self.model1 = nn.Sequential(*model1)
self.model2 = nn.Sequential(*model2)
self.model3 = nn.Sequential(*model3)
self.model4 = nn.Sequential(*model4)
self.model5 = nn.Sequential(*model5)
self.model6 = nn.Sequential(*model6)
self.model7 = nn.Sequential(*model7)
self.model8up = nn.Sequential(*model8up)
self.model8 = nn.Sequential(*model8)
self.model9up = nn.Sequential(*model9up)
self.model9 = nn.Sequential(*model9)
self.model10up = nn.Sequential(*model10up)
self.model10 = nn.Sequential(*model10)
self.model3short8 = nn.Sequential(*model3short8)
self.model2short9 = nn.Sequential(*model2short9)
self.model1short10 = nn.Sequential(*model1short10)
self.model_class = nn.Sequential(*model_class)
self.model_out = nn.Sequential(*model_out)
if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained model success")
def transforms(self, images: str, is_train: bool = True) -> callable:
if is_train:
transform = Compose([
Resize((256, 256), interp="RANDOM"),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True)
else:
transform = Compose([
Resize((256, 256), interp="RANDOM"),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True)
return transform(images)
def forward(self,
input_A: paddle.Tensor,
input_B: paddle.Tensor,
mask_B: paddle.Tensor,
real_b: paddle.Tensor = None,
real_B_enc: paddle.Tensor = None) -> paddle.Tensor:
conv1_2 = self.model1(paddle.concat([input_A, input_B, mask_B], axis=1))
conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
conv5_3 = self.model5(conv4_3)
conv6_3 = self.model6(conv5_3)
conv7_3 = self.model7(conv6_3)
conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
conv8_3 = self.model8(conv8_up)
if self.classification:
out_class = self.model_class(conv8_3)
conv9_up = self.model9up(conv8_3.detach()) + self.model2short9(conv2_2.detach())
conv9_3 = self.model9(conv9_up)
conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2.detach())
conv10_2 = self.model10(conv10_up)
out_reg = self.model_out(conv10_2)
else:
out_class = self.model_class(conv8_3.detach())
conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
conv9_3 = self.model9(conv9_up)
conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
conv10_2 = self.model10(conv10_up)
out_reg = self.model_out(conv10_2)
return out_class, out_reg
if __name__ == "__main__":
place = paddle.CUDAPlace(0)
paddle.disable_static()
model = UserGuidedColorization()
model.eval()
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy
import paddle
from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
from typing import Callable
class Colorizedataset(paddle.io.Dataset):
"""
Dataset for colorization.
Args:
transform(callmethod) : The method of preprocess images.
mode(str): The mode for preparing dataset.
Returns:
DataSet: An iterable object for data iterating
"""
def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode
self.transform = transform
if self.mode == 'train':
self.file = 'train'
elif self.mode == 'test':
self.file = 'test'
else:
self.file = 'validation'
self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file)
def __getitem__(self, idx: int) -> numpy.ndarray:
img_path = self.data[idx]
im = self.transform(img_path)
return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc']
def __len__(self):
return len(self.data)
\ No newline at end of file
...@@ -13,14 +13,20 @@ ...@@ -13,14 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import time
import os
from typing import List from typing import List
from collections import OrderedDict
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from PIL import Image
from paddlehub.module.module import serving, RunModule from paddlehub.module.module import serving, RunModule
from paddlehub.utils.utils import base64_to_cv2 from paddlehub.utils.utils import base64_to_cv2
from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
class ImageServing(object): class ImageServing(object):
...@@ -91,3 +97,96 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -91,3 +97,96 @@ class ImageClassifierModule(RunModule, ImageServing):
res_dict[class_name] = preds[i][k] res_dict[class_name] = preds[i][k]
res.append(res_dict) res.append(res_dict)
return res return res
class ImageColorizeModule(RunModule, ImageServing):
def training_step(self, batch: int, batch_idx: int) -> dict:
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as loss and metrics.
'''
return self.validation_step(batch, batch_idx)
def validation_step(self, batch: int, batch_idx: int) -> dict:
'''
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
'''
out_class, out_reg = self(batch[0], batch[1], batch[2])
criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, batch[4][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg
visual_ret = OrderedDict()
psnrs = []
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
process = ColorPostprocess()
for i in range(batch[0].numpy().shape[0]):
real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
visual_ret['real'] = process(real)
fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = process(fake)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0) ** 2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs))
return {'loss': loss, 'metrics': {'psnr': psnr}}
def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
'''
Colorize images
Args:
images(str) : Images path to be colorized.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
Returns:
results(list[dict]) : The prediction result of each input image
'''
lab2rgb = ConvertColorSpace(mode='LAB2RGB')
process = ColorPostprocess()
resize = Resize((256, 256))
visual_ret = OrderedDict()
im = self.transforms(images, is_train=False)
out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']),
paddle.to_variable(im['mask_B']))
result = []
for i in range(im['A'].shape[0]):
gray = lab2rgb(np.concatenate((im['A'], np.zeros(im['B'].shape)), axis=1))[i]
visual_ret['gray'] = resize(process(gray))
hint = lab2rgb(np.concatenate((im['A'], im['hint_B']), axis=1))[i]
visual_ret['hint'] = resize(process(hint))
real = lab2rgb(np.concatenate((im['A'], im['B']), axis=1))[i]
visual_ret['real'] = resize(process(real))
fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = resize(process(fake))
if visualization:
fake_name = "fake_" + str(time.time()) + ".png"
if not os.path.exists(save_path):
os.mkdir(save_path)
fake_path = os.path.join(save_path, fake_name)
visual_gray = Image.fromarray(visual_ret['fake_reg'])
visual_gray.save(fake_path)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0) ** 2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
result.append(visual_ret)
return result
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
...@@ -96,3 +98,23 @@ def rotate(im, rotate_lower, rotate_upper): ...@@ -96,3 +98,23 @@ def rotate(im, rotate_lower, rotate_upper):
rotate_delta = np.random.uniform(rotate_lower, rotate_upper) rotate_delta = np.random.uniform(rotate_lower, rotate_upper)
im = im.rotate(int(rotate_delta)) im = im.rotate(int(rotate_delta))
return im return im
def is_image_file(filename: str) -> bool:
'''Determine whether the input file name is a valid image file name.'''
ext = os.path.splitext(filename)[-1].lower()
return ext in ['.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff']
def get_img_file(dir_name: str) -> list:
'''Get all image file paths in several directories which have the same parent directory.'''
images = []
for parent, dirnames, filenames in os.walk(dir_name):
for filename in filenames:
if not is_image_file(filename):
continue
img_path = os.path.join(parent, filename)
print(img_path)
images.append(img_path)
images.sort()
return images
\ No newline at end of file
...@@ -16,15 +16,15 @@ ...@@ -16,15 +16,15 @@
import random import random
from collections import OrderedDict from collections import OrderedDict
import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import cv2
from paddlehub.process.functional import * from paddlehub.process.functional import *
class Compose: class Compose:
def __init__(self, transforms, to_rgb=True): def __init__(self, transforms, to_rgb=True, stay_rgb=False):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
if len(transforms) < 1: if len(transforms) < 1:
...@@ -32,6 +32,7 @@ class Compose: ...@@ -32,6 +32,7 @@ class Compose:
'must be equal or larger than 1!') 'must be equal or larger than 1!')
self.transforms = transforms self.transforms = transforms
self.to_rgb = to_rgb self.to_rgb = to_rgb
self.stay_rgb = stay_rgb
def __call__(self, im): def __call__(self, im):
if isinstance(im, str): if isinstance(im, str):
...@@ -44,10 +45,15 @@ class Compose: ...@@ -44,10 +45,15 @@ class Compose:
for op in self.transforms: for op in self.transforms:
im = op(im) im = op(im)
im = permute(im)
if not self.stay_rgb:
im = permute(im)
return im return im
class RandomHorizontalFlip: class RandomHorizontalFlip:
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -396,3 +402,314 @@ class RandomDistort: ...@@ -396,3 +402,314 @@ class RandomDistort:
im = np.asarray(im).astype('float32') im = np.asarray(im).astype('float32')
return im return im
class ConvertColorSpace:
"""
Convert color space from RGB to LAB or from LAB to RGB.
Args:
mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'.
Return:
img(np.ndarray): converted image.
"""
def __init__(self, mode: str = 'RGB2LAB'):
self.mode = mode
def rgb2xyz(self, rgb: np.ndarray) -> np.ndarray:
"""
Convert color space from RGB to XYZ.
Args:
img(np.ndarray): Original RGB image.
Return:
img(np.ndarray): Converted XYZ image.
"""
mask = (rgb > 0.04045)
np.seterr(invalid='ignore')
rgb = (((rgb + .055) / 1.055) ** 2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = np.nan_to_num(rgb)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] + .950227 * rgb[:, 2, :, :]
out = np.concatenate((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), axis=1)
return out
def xyz2lab(self, xyz: np.ndarray) -> np.ndarray:
"""
Convert color space from XYZ to LAB.
Args:
img(np.ndarray): Original XYZ image.
Return:
img(np.ndarray): Converted LAB image.
"""
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
xyz_scale = xyz / sc
mask = (xyz_scale > .008856).astype(np.float32)
xyz_int = np.cbrt(xyz_scale) * mask + (7.787 * xyz_scale + 16. / 116.) * (1 - mask)
L = 116. * xyz_int[:, 1, :, :] - 16.
a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :])
b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :])
out = np.concatenate((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), axis=1)
return out
def rgb2lab(self, rgb: np.ndarray) -> np.ndarray:
"""
Convert color space from RGB to LAB.
Args:
img(np.ndarray): Original RGB image.
Return:
img(np.ndarray): Converted LAB image.
"""
lab = self.xyz2lab(self.rgb2xyz(rgb))
l_rs = (lab[:, [0], :, :] - 50) / 100
ab_rs = lab[:, 1:, :, :] / 110
out = np.concatenate((l_rs, ab_rs), axis=1)
return out
def xyz2rgb(self, xyz: np.ndarray) -> np.ndarray:
"""
Convert color space from XYZ to RGB.
Args:
img(np.ndarray): Original XYZ image.
Return:
img(np.ndarray): Converted RGB image.
"""
r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] - 0.49853633 * xyz[:, 2, :, :]
g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] + .04155593 * xyz[:, 2, :, :]
b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] + 1.05731107 * xyz[:, 2, :, :]
rgb = np.concatenate((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), axis=1)
rgb = np.maximum(rgb, 0) # sometimes reaches a small negative number, which causes NaNs
mask = (rgb > .0031308).astype(np.float32)
np.seterr(invalid='ignore')
out = (1.055 * (rgb ** (1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
out = np.nan_to_num(out)
return out
def lab2xyz(self, lab: np.ndarray) -> np.ndarray:
"""
Convert color space from LAB to XYZ.
Args:
img(np.ndarray): Original LAB image.
Return:
img(np.ndarray): Converted XYZ image.
"""
y_int = (lab[:, 0, :, :] + 16.) / 116.
x_int = (lab[:, 1, :, :] / 500.) + y_int
z_int = y_int - (lab[:, 2, :, :] / 200.)
z_int = np.maximum(z_int, 0)
out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1)
mask = (out > .2068966).astype(np.float32)
np.seterr(invalid='ignore')
out = (out ** 3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
out = np.nan_to_num(out)
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
out = out * sc
return out
def lab2rgb(self, lab_rs: np.ndarray) -> np.ndarray:
"""
Convert color space from LAB to RGB.
Args:
img(np.ndarray): Original LAB image.
Return:
img(np.ndarray): Converted RGB image.
"""
l = lab_rs[:, [0], :, :] * 100 + 50
ab = lab_rs[:, 1:, :, :] * 110
lab = np.concatenate((l, ab), axis=1)
out = self.xyz2rgb(self.lab2xyz(lab))
return out
def __call__(self, img: np.ndarray) -> np.ndarray:
if self.mode == 'RGB2LAB':
img = np.expand_dims(img / 255, 0)
img = np.array(img).transpose(0, 3, 1, 2)
return self.rgb2lab(img)
elif self.mode == 'LAB2RGB':
return self.lab2rgb(img)
else:
raise ValueError('The mode should be RGB2LAB or LAB2RGB')
class ColorizeHint:
"""Get hint and mask images for colorization.
This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process.
Args:
percent(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area.
Return:
hint(np.ndarray): hint images
mask(np.ndarray): mask images
"""
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
self.percent = percent
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9, ]
self.data = data
self.hint = hint
self.mask = mask
N, C, H, W = data.shape
for nn in range(N):
pp = 0
cont_cond = True
while cont_cond:
if self.num_points is None: # draw from geometric
# embed()
cont_cond = np.random.rand() < (1 - self.percent)
else: # add certain number of points
cont_cond = pp < self.num_points
if not cont_cond: # skip out of loop if condition not met
continue
P = np.random.choice(sample_Ps) # patch size
# sample location
if self.samp == 'normal': # geometric distribution
h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P))
w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P))
else: # uniform distribution
h = np.random.randint(H - P + 1)
w = np.random.randint(W - P + 1)
# add color point
if self.use_avg:
# embed()
hint[nn, :, h:h + P, w:w + P] = np.mean(
np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), axis=1, keepdims=True).reshape(
1, C, 1, 1)
else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1
# increment counter
pp += 1
mask -= 0.5
return hint, mask
class SqueezeAxis:
"""
Squeeze the specific axis when it equal to 1.
Args:
axis(int): Which axis should be squeezed.
"""
def __init__(self, axis: int):
self.axis = axis
def __call__(self, data: dict):
if isinstance(data, dict):
for key in data.keys():
data[key] = np.squeeze(data[key], 0).astype(np.float32)
return data
else:
raise TypeError("Type of data is invalid. Must be Dict or List or tuple, now is {}".format(type(data)))
class ColorizePreprocess:
"""Prepare dataset for image Colorization.
Args:
ab_thresh(float): Thresh value for setting mask value.
p(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area.
is_train(bool): Training process or not.
Return:
data(dict):The preprocessed data for colorization.
"""
def __init__(self, ab_thresh: float = 0.,
p: float = .125,
num_points: int = None,
samp: str = 'normal',
use_avg: bool = True,
is_train: bool = True):
self.ab_thresh = ab_thresh
self.p = p
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
self.is_train = is_train
self.gethint = ColorizeHint(percent=self.p, num_points=self.num_points, samp=self.samp, use_avg=self.use_avg)
self.squeeze = SqueezeAxis(0)
def __call__(self, data_lab: np.ndarray):
"""
This method seperates the L channel and AB channel, obtain hint, mask and real_B_enc as the input for colorization task.
Args:
img(np.ndarray): LAB image.
Returns:
data(dict):The preprocessed data for colorization.
"""
data = {}
A = 2 * 110 / 10 + 1
data['A'] = data_lab[:, [0, ], :, :]
data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),axis=1)
mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :]
if np.sum(mask) == 0:
return None
data_ab_rs = np.round((data['B'][:, :, ::4, ::4] * 110. + 110.) / 10.) # normalized bin number
data['real_B_enc'] = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :]
data['hint_B'] = np.zeros(shape=data['B'].shape)
data['mask_B'] = np.zeros(shape=data['A'].shape)
data['hint_B'], data['mask_B'] = self.gethint(data['B'], data['hint_B'], data['mask_B'])
if self.is_train:
data = self.squeeze(data)
data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
else:
data['A'] = data['A'].astype(np.float32)
data['B'] = data['B'].astype(np.float32)
data['real_B_enc'] = data['real_B_enc'].astype(np.int64)
data['hint_B'] = data['hint_B'].astype(np.float32)
data['mask_B'] = data['mask_B'].astype(np.float32)
return data
class ColorPostprocess:
"""
Transform images from [0, 1] to [0, 255]
Args:
type(type): Type of Image value.
Return:
img(np.ndarray): Image in range of 0-255.
"""
def __init__(self, type: type = np.uint8):
self.type = type
def __call__(self, img: np.ndarray):
img = np.transpose(img, (1, 2, 0))
img = np.clip(img, 0, 1) * 255
img = img.astype(self.type)
return img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册