未验证 提交 96096612 编写于 作者: H haoyuying 提交者: GitHub

Colorize (#897)

上级 5889f7cf
......@@ -2,9 +2,8 @@ import paddle
import paddlehub as hub
import paddle.nn as nn
if __name__ == '__main__':
paddle.disable_static()
model = hub.Module(directory='user_guided_colorization')
model = hub.Module(name='user_guided_colorization')
model.eval()
result = model.predict(images='sea.jpg')
\ No newline at end of file
result = model.predict(images='house.png')
......@@ -3,18 +3,24 @@ import paddlehub as hub
import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
if __name__ == '__main__':
is_train = True
paddle.disable_static()
model = hub.Module(directory='user_guided_colorization')
transform = Compose([Resize((256,256),interp="RANDOM"),RandomPaddingCrop(crop_size=176), ConvertColorSpace(mode='RGB2LAB'), ColorizePreprocess(ab_thresh=0, p=1)], stay_rgb=True)
color_set = Colorizedataset(transform=transform, mode=is_train)
model = hub.Module(name='user_guided_colorization')
transform = Compose([
Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train),
],
stay_rgb=True,
is_permute=False)
color_set = Colorizedataset(transform=transform, mode='train')
if is_train:
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1)
trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import numpy
import paddle.nn as nn
from paddle.nn import Conv2d, ConvTranspose2d
from paddlehub.module.module import moduleinfo
from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
from paddlehub.module.cv_module import ImageColorizeModule
......@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer):
if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained model success")
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained checkpoint success")
def transforms(self, images: str, is_train: bool = True) -> callable:
if is_train:
transform = Compose([
Resize((256, 256), interp="RANDOM"),
Resize((256, 256), interp='NEAREST'),
RandomPaddingCrop(crop_size=176),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True)
stay_rgb=True,
is_permute=False)
else:
transform = Compose([
Resize((256, 256), interp="RANDOM"),
Resize((256, 256), interp='NEAREST'),
ConvertColorSpace(mode='RGB2LAB'),
ColorizePreprocess(ab_thresh=0, is_train=is_train)
],
stay_rgb=True)
stay_rgb=True,
is_permute=False)
return transform(images)
def forward(self,
......
......@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file
from paddlehub.env import DATA_HOME
from typing import Callable
class Colorizedataset(paddle.io.Dataset):
"""
Dataset for colorization.
......@@ -34,14 +35,12 @@ class Colorizedataset(paddle.io.Dataset):
def __init__(self, transform: Callable, mode: str = 'train'):
self.mode = mode
self.transform = transform
if self.mode == 'train':
self.file = 'train'
elif self.mode == 'test':
self.file = 'test'
else:
self.file = 'validation'
self.file = os.path.join(DATA_HOME, 'canvas', self.file)
self.data = get_img_file(self.file)
......@@ -51,4 +50,4 @@ class Colorizedataset(paddle.io.Dataset):
return im['A'], im['hint_B'], im['mask_B'], im['B'], im['real_B_enc']
def __len__(self):
return len(self.data)
\ No newline at end of file
return len(self.data)
......@@ -18,6 +18,7 @@ import os
from typing import List
from collections import OrderedDict
import cv2
import numpy as np
import paddle
import paddle.nn as nn
......@@ -27,6 +28,7 @@ from PIL import Image
from paddlehub.module.module import serving, RunModule
from paddlehub.utils.utils import base64_to_cv2
from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
from paddlehub.process.functional import subtract_imagenet_mean_batch, gram_matrix
class ImageServing(object):
......@@ -192,3 +194,87 @@ class ImageColorizeModule(RunModule, ImageServing):
psnr_value = 20 * np.log10(255. / np.sqrt(mse))
result.append(visual_ret)
return result
class StyleTransferModule(RunModule, ImageServing):
    """Training, validation and prediction steps shared by style-transfer modules.

    The methods below rely on members that are not defined in this class and must be
    supplied by the concrete module: ``setTarget``, ``getFeature``, ``transform`` and
    the forward pass invoked through ``self(...)``.
    """

    def training_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
        '''
        One step for training, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data; ``batch[0]`` holds the content images
                and ``batch[1]`` the style images.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs, such as loss and metrics.
        '''
        # Training and validation compute exactly the same losses.
        return self.validation_step(batch, batch_idx)

    def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
        '''
        One step for validation, which should be called as forward computation.

        Note that ``batch[1]`` is overwritten in place: first with its first sample only,
        then with the mean-subtracted version of that sample.

        Args:
            batch(list[paddle.Tensor]): The one batch data; ``batch[0]`` holds the content images
                and ``batch[1]`` the style images.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : ``{'loss': ..., 'metrics': {'content gap': ..., 'style gap': ...}}``.
        '''
        mse_loss = nn.MSELoss()
        # Only the batch size is used below; the unpacking still requires a 4-D input.
        N, _, _, _ = batch[0].shape

        # A single style image (the first of the batch) is used as the target.
        batch[1] = batch[1][0].unsqueeze(0)
        self.setTarget(batch[1])

        y = self(batch[0])
        # Rebuild the content batch from a numpy copy so it is a separate tensor from batch[0].
        xc = paddle.to_tensor(batch[0].numpy().copy())
        y = subtract_imagenet_mean_batch(y)
        xc = subtract_imagenet_mean_batch(xc)
        features_y = self.getFeature(y)
        features_xc = self.getFeature(xc)

        # Content loss compares feature index 1 only; the target is excluded from gradients.
        f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
        content_loss = mse_loss(features_y[1], f_xc_c)

        batch[1] = subtract_imagenet_mean_batch(batch[1])
        features_style = self.getFeature(batch[1])
        gram_style = [gram_matrix(feature) for feature in features_style]

        style_loss = 0.
        for m in range(len(features_y)):
            gram_y = gram_matrix(features_y[m])
            # The style gram is repeated N times so every sample is compared to the same target.
            # NOTE(review): tiling with a 4-element ``reps`` adds a leading axis to the gram
            # array, so the loss relies on broadcasting against ``gram_y`` - confirm the shapes.
            gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
            style_loss += mse_loss(gram_y, gram_s[:N, :, :])

        loss = content_loss + style_loss

        return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}

    def predict(self, origin_path: str, style_path: str, visualization: bool = True, save_path: str = 'result'):
        '''
        Transfer the style of one image onto another image.

        Args:
            origin_path(str): Content image path .
            style_path(str): Style image path.
            visualization(bool): Whether to save the stylized image.
            save_path(str) : Directory in which the stylized image is saved; it is created
                (including missing parent directories) when it does not exist.

        Returns:
            output(np.ndarray) : The style transformed image, described as bgr mode by the original
                documentation. It is cast to uint8 only when ``visualization`` is True.
        '''
        content = paddle.to_tensor(self.transform(origin_path))
        style = paddle.to_tensor(self.transform(style_path))
        # Add the batch axis expected by the network.
        content = content.unsqueeze(0)
        style = style.unsqueeze(0)

        self.setTarget(style)
        output = self(content)
        # Move channels last and clamp to the displayable range.
        output = paddle.clip(output[0].transpose((1, 2, 0)), 0, 255).numpy()

        if visualization:
            output = output.astype(np.uint8)
            # The timestamp keeps successive results from overwriting each other.
            style_name = "style_" + str(time.time()) + ".png"
            # makedirs handles nested paths and an already-existing directory without a
            # separate existence check.
            os.makedirs(save_path, exist_ok=True)
            path = os.path.join(save_path, style_name)
            cv2.imwrite(path, output)
        return output
......@@ -15,6 +15,7 @@
import os
import cv2
import paddle
import numpy as np
from PIL import Image, ImageEnhance
......@@ -114,7 +115,25 @@ def get_img_file(dir_name: str) -> list:
if not is_image_file(filename):
continue
img_path = os.path.join(parent, filename)
print(img_path)
images.append(img_path)
images.sort()
return images
\ No newline at end of file
return images
def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
    """Subtract a fixed per-channel mean from every pixel of a 4-D batch.

    A float32 array shaped like ``batch`` is filled with one constant for each of the
    channel indices 0, 1 and 2 (axis 1) and subtracted element-wise. The original
    documentation describes these constants as the ImageNet mean for a BGR image.
    """
    channel_means = (103.939, 116.779, 123.680)
    offsets = np.zeros(shape=batch.shape, dtype='float32')
    for channel, value in enumerate(channel_means):
        offsets[:, channel, :, :] = value
    return batch - paddle.to_tensor(offsets)
def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
    """Compute the batched Gram matrix of a 4-D tensor.

    The two trailing axes are flattened, the result is multiplied by its own
    transpose per batch entry, and the product is divided by the number of
    elements in one sample (``ch * h * w``).
    """
    batch_size, channels, height, width = data.shape
    flat = data.reshape((batch_size, channels, width * height))
    normalizer = channels * height * width
    return flat.bmm(flat.transpose((0, 2, 1))) / normalizer
......@@ -24,15 +24,16 @@ from paddlehub.process.functional import *
class Compose:
def __init__(self, transforms, to_rgb=True, stay_rgb=False):
def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
raise ValueError('The length of transforms ' + \
'must be equal or larger than 1!')
'must be equal or larger than 1!')
self.transforms = transforms
self.to_rgb = to_rgb
self.stay_rgb = stay_rgb
self.is_permute = is_permute
def __call__(self, im):
if isinstance(im, str):
......@@ -45,15 +46,16 @@ class Compose:
for op in self.transforms:
im = op(im)
if not self.stay_rgb:
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
if self.is_permute:
im = permute(im)
return im
class RandomHorizontalFlip:
def __init__(self, prob=0.5):
self.prob = prob
......@@ -239,8 +241,13 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder(
im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.im_padding_value)
im = cv2.copyMakeBorder(im,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
if crop_height > 0 and crop_width > 0:
h_off = np.random.randint(img_height - crop_height + 1)
......@@ -295,13 +302,12 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy
dsize = (nw, nh)
im = cv2.warpAffine(
im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value)
im = cv2.warpAffine(im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value)
return im
......@@ -403,14 +409,14 @@ class RandomDistort:
return im
class ConvertColorSpace:
"""
Convert color space from RGB to LAB or from LAB to RGB.
Args:
mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'.
Return:
img(np.ndarray): converted image.
"""
......@@ -429,7 +435,7 @@ class ConvertColorSpace:
"""
mask = (rgb > 0.04045)
np.seterr(invalid='ignore')
rgb = (((rgb + .055) / 1.055) ** 2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = np.nan_to_num(rgb)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
......@@ -490,7 +496,7 @@ class ConvertColorSpace:
rgb = np.maximum(rgb, 0) # sometimes reaches a small negative number, which causes NaNs
mask = (rgb > .0031308).astype(np.float32)
np.seterr(invalid='ignore')
out = (1.055 * (rgb ** (1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
out = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
out = np.nan_to_num(out)
return out
......@@ -511,7 +517,7 @@ class ConvertColorSpace:
out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1)
mask = (out > .2068966).astype(np.float32)
np.seterr(invalid='ignore')
out = (out ** 3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
out = np.nan_to_num(out)
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
out = out * sc
......@@ -546,27 +552,27 @@ class ConvertColorSpace:
class ColorizeHint:
"""Get hint and mask images for colorization.
This method is prepared for user guided colorization tasks. Taking the original RGB images as input, we obtain the local hints and the corresponding mask to guide the colorization process.
Args:
percent(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area.
Return:
hint(np.ndarray): hint images
mask(np.ndarray): mask images
"""
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
self.percent = percent
self.num_points = num_points
self.samp = samp
self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9, ]
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.data = data
self.hint = hint
self.mask = mask
......@@ -577,7 +583,7 @@ class ColorizeHint:
while cont_cond:
if self.num_points is None: # draw from geometric
# embed()
cont_cond = np.random.rand() < (1 - self.percent)
cont_cond = np.random.rand() > (1 - self.percent)
else: # add certain number of points
cont_cond = pp < self.num_points
if not cont_cond: # skip out of loop if condition not met
......@@ -593,9 +599,11 @@ class ColorizeHint:
# add color point
if self.use_avg:
# embed()
hint[nn, :, h:h + P, w:w + P] = np.mean(
np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), axis=1, keepdims=True).reshape(
1, C, 1, 1)
hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
axis=2,
keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1
......@@ -609,10 +617,10 @@ class ColorizeHint:
class SqueezeAxis:
"""
Squeeze the specific axis when it equal to 1.
Args:
axis(int): Which axis should be squeezed.
"""
def __init__(self, axis: int):
self.axis = axis
......@@ -628,7 +636,7 @@ class SqueezeAxis:
class ColorizePreprocess:
"""Prepare dataset for image Colorization.
Args:
ab_thresh(float): Thresh value for setting mask value.
p(float): Probability for ignoring hint in an iteration.
......@@ -636,13 +644,14 @@ class ColorizePreprocess:
samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area.
is_train(bool): Training process or not.
Return:
data(dict):The preprocessed data for colorization.
"""
def __init__(self, ab_thresh: float = 0.,
p: float = .125,
def __init__(self,
ab_thresh: float = 0.,
p: float = 0.,
num_points: int = None,
samp: str = 'normal',
use_avg: bool = True,
......@@ -668,11 +677,14 @@ class ColorizePreprocess:
"""
data = {}
A = 2 * 110 / 10 + 1
data['A'] = data_lab[:, [0, ], :, :]
data['A'] = data_lab[:, [
0,
], :, :]
data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),axis=1)
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :]
......@@ -698,10 +710,10 @@ class ColorizePreprocess:
class ColorPostprocess:
"""
Transform images from [0, 1] to [0, 255]
Args:
type(type): Type of Image value.
Return:
img(np.ndarray): Image in range of 0-255.
"""
......@@ -713,3 +725,41 @@ class ColorPostprocess:
img = np.clip(img, 0, 1) * 255
img = img.astype(self.type)
return img
class CenterCrop:
    """
    Crop the middle part of the image to the specified size.

    Only the first two axes are cropped, so both multi-channel arrays
    and 2-D single-channel arrays are accepted.

    Args:
        crop_size(int): Side length of the square crop.

    Return:
        img(np.ndarray): Cropped image.
    """
    def __init__(self, crop_size: int):
        self.crop_size = crop_size

    def __call__(self, img: np.ndarray):
        # Axis 0 indexes rows and axis 1 indexes columns; any trailing axes are kept whole.
        rows, cols = img.shape[:2]
        # int() truncates toward zero, matching the original offset computation.
        top = int((rows - self.crop_size) / 2.)
        left = int((cols - self.crop_size) / 2.)
        return img[top:top + self.crop_size, left:left + self.crop_size]
class SetType:
    """
    Cast an image array to a given element type.

    Args:
        datatype(type): Target element type handed to ``astype``; defaults to 'float32'.

    Return:
        img(np.ndarray): The image converted to the requested type.
    """
    def __init__(self, datatype: type = 'float32'):
        self.type = datatype

    def __call__(self, img: np.ndarray):
        return img.astype(self.type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册