Unverified · Commit 96096612 authored by H haoyuying, committed by GitHub

Colorize (#897)

Parent 5889f7cf
@@ -2,9 +2,8 @@ import paddle
 import paddlehub as hub
 import paddle.nn as nn

 if __name__ == '__main__':
     paddle.disable_static()
-    model = hub.Module(directory='user_guided_colorization')
+    model = hub.Module(name='user_guided_colorization')
     model.eval()
-    result = model.predict(images='sea.jpg')
+    result = model.predict(images='house.png')
\ No newline at end of file
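For reference, a minimal sketch of the two ways hub.Module can locate this module in the demo above; the local directory path is a placeholder.

import paddlehub as hub

# Load by name: PaddleHub resolves 'user_guided_colorization' from its module
# registry (the form the demo now uses).
model = hub.Module(name='user_guided_colorization')

# Load from a local module folder instead (the form the demo used before);
# './user_guided_colorization' is a placeholder path.
# model = hub.Module(directory='./user_guided_colorization')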
@@ -6,15 +6,21 @@ from paddlehub.finetune.trainer import Trainer
 from paddlehub.datasets.colorizedataset import Colorizedataset
 from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess

 if __name__ == '__main__':
     is_train = True
     paddle.disable_static()
-    model = hub.Module(directory='user_guided_colorization')
-    transform = Compose([Resize((256,256),interp="RANDOM"),RandomPaddingCrop(crop_size=176), ConvertColorSpace(mode='RGB2LAB'), ColorizePreprocess(ab_thresh=0, p=1)], stay_rgb=True)
-    color_set = Colorizedataset(transform=transform, mode=is_train)
+    model = hub.Module(name='user_guided_colorization')
+    transform = Compose([
+        Resize((256, 256), interp='NEAREST'),
+        RandomPaddingCrop(crop_size=176),
+        ConvertColorSpace(mode='RGB2LAB'),
+        ColorizePreprocess(ab_thresh=0, is_train=is_train),
+    ],
+                        stay_rgb=True,
+                        is_permute=False)
+    color_set = Colorizedataset(transform=transform, mode='train')
     if is_train:
         model.train()
         optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
         trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
-        trainer.train(color_set, epochs=3, batch_size=1, eval_dataset=color_set, save_interval=1)
+        trainer.train(color_set, epochs=101, batch_size=5, eval_dataset=color_set, log_interval=10, save_interval=10)
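The training demo above pins the dataset split to 'train' and passes is_train into ColorizePreprocess. For the matching evaluation-time setup, a hedged sketch that reuses only pieces visible elsewhere in this diff (the is_train=False transform branch and the 'test' dataset split); it assumes the canvas dataset is already present under DATA_HOME.

import paddle
from paddlehub.datasets.colorizedataset import Colorizedataset
from paddlehub.process.transforms import Compose, Resize, ConvertColorSpace, ColorizePreprocess

if __name__ == '__main__':
    paddle.disable_static()
    # Evaluation-time preprocessing: the same pipeline minus the random crop,
    # mirroring the is_train=False branch of the module's transforms().
    eval_transform = Compose([
        Resize((256, 256), interp='NEAREST'),
        ConvertColorSpace(mode='RGB2LAB'),
        ColorizePreprocess(ab_thresh=0, is_train=False),
    ],
                             stay_rgb=True,
                             is_permute=False)
    eval_set = Colorizedataset(transform=eval_transform, mode='test')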
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import paddle
 import numpy
 import paddle.nn as nn
 from paddle.nn import Conv2d, ConvTranspose2d
 from paddlehub.module.module import moduleinfo
 from paddlehub.process.transforms import Compose, Resize, RandomPaddingCrop, ConvertColorSpace, ColorizePreprocess
 from paddlehub.module.cv_module import ImageColorizeModule
@@ -178,24 +179,31 @@ class UserGuidedColorization(nn.Layer):
         if load_checkpoint is not None:
             model_dict = paddle.load(load_checkpoint)[0]
             self.set_dict(model_dict)
-            print("load pretrained model success")
+            print("load custom checkpoint success")
+        else:
+            checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
+            model_dict = paddle.load(checkpoint)[0]
+            self.set_dict(model_dict)
+            print("load pretrained checkpoint success")

     def transforms(self, images: str, is_train: bool = True) -> callable:
         if is_train:
             transform = Compose([
-                Resize((256, 256), interp="RANDOM"),
+                Resize((256, 256), interp='NEAREST'),
                 RandomPaddingCrop(crop_size=176),
                 ConvertColorSpace(mode='RGB2LAB'),
                 ColorizePreprocess(ab_thresh=0, is_train=is_train)
             ],
-                                stay_rgb=True)
+                                stay_rgb=True,
+                                is_permute=False)
         else:
             transform = Compose([
-                Resize((256, 256), interp="RANDOM"),
+                Resize((256, 256), interp='NEAREST'),
                 ConvertColorSpace(mode='RGB2LAB'),
                 ColorizePreprocess(ab_thresh=0, is_train=is_train)
             ],
-                                stay_rgb=True)
+                                stay_rgb=True,
+                                is_permute=False)
         return transform(images)

     def forward(self,
...
@@ -22,6 +22,7 @@ from paddlehub.process.functional import get_img_file
 from paddlehub.env import DATA_HOME
 from typing import Callable

+
 class Colorizedataset(paddle.io.Dataset):
     """
     Dataset for colorization.
@@ -39,8 +40,6 @@ class Colorizedataset(paddle.io.Dataset):
             self.file = 'train'
         elif self.mode == 'test':
             self.file = 'test'
-        else:
-            self.file = 'validation'

         self.file = os.path.join(DATA_HOME, 'canvas', self.file)
         self.data = get_img_file(self.file)
...
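With the 'validation' branch removed, the dataset only resolves the 'train' and 'test' splits under DATA_HOME. A small sketch of the directories get_img_file() will scan:

import os
from paddlehub.env import DATA_HOME

# The dataset now only knows the 'train' and 'test' splits, both rooted at
# DATA_HOME/canvas.
for split in ('train', 'test'):
    print(os.path.join(DATA_HOME, 'canvas', split))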
@@ -18,6 +18,7 @@ import os
 from typing import List
 from collections import OrderedDict

+import cv2
 import numpy as np
 import paddle
 import paddle.nn as nn
@@ -27,6 +28,7 @@ from PIL import Image
 from paddlehub.module.module import serving, RunModule
 from paddlehub.utils.utils import base64_to_cv2
 from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
+from paddlehub.process.functional import subtract_imagenet_mean_batch, gram_matrix


 class ImageServing(object):
@@ -192,3 +194,87 @@ class ImageColorizeModule(RunModule, ImageServing):
             psnr_value = 20 * np.log10(255. / np.sqrt(mse))
             result.append(visual_ret)
         return result
+
+
+class StyleTransferModule(RunModule, ImageServing):
+    def training_step(self, batch: int, batch_idx: int) -> dict:
+        '''
+        One step for training, which should be called as forward computation.
+
+        Args:
+            batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
+            batch_idx(int): The index of batch.
+
+        Returns:
+            results(dict): The model outputs, such as loss and metrics.
+        '''
+        return self.validation_step(batch, batch_idx)
+
+    def validation_step(self, batch: int, batch_idx: int) -> dict:
+        '''
+        One step for validation, which should be called as forward computation.
+
+        Args:
+            batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
+            batch_idx(int): The index of batch.
+
+        Returns:
+            results(dict): The model outputs, such as metrics.
+        '''
+        mse_loss = nn.MSELoss()
+        N, C, H, W = batch[0].shape
+        batch[1] = batch[1][0].unsqueeze(0)
+        self.setTarget(batch[1])
+
+        y = self(batch[0])
+        xc = paddle.to_tensor(batch[0].numpy().copy())
+        y = subtract_imagenet_mean_batch(y)
+        xc = subtract_imagenet_mean_batch(xc)
+        features_y = self.getFeature(y)
+        features_xc = self.getFeature(xc)
+        f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
+        content_loss = mse_loss(features_y[1], f_xc_c)
+
+        batch[1] = subtract_imagenet_mean_batch(batch[1])
+        features_style = self.getFeature(batch[1])
+        gram_style = [gram_matrix(y) for y in features_style]
+        style_loss = 0.
+        for m in range(len(features_y)):
+            gram_y = gram_matrix(features_y[m])
+            gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
+            style_loss += mse_loss(gram_y, gram_s[:N, :, :])
+
+        loss = content_loss + style_loss
+        return {'loss': loss, 'metrics': {'content gap': content_loss, 'style gap': style_loss}}
+
+    def predict(self, origin_path: str, style_path: str, visualization: bool = True, save_path: str = 'result'):
+        '''
+        Style transfer.
+
+        Args:
+            origin_path(str): Content image path.
+            style_path(str): Style image path.
+            visualization(bool): Whether to save the transformed images.
+            save_path(str): Path to save the transformed images.
+
+        Returns:
+            output(np.ndarray): The style transformed images with bgr mode.
+        '''
+        content = paddle.to_tensor(self.transform(origin_path))
+        style = paddle.to_tensor(self.transform(style_path))
+        content = content.unsqueeze(0)
+        style = style.unsqueeze(0)
+        self.setTarget(style)
+        output = self(content)
+        output = paddle.clip(output[0].transpose((1, 2, 0)), 0, 255).numpy()
+        if visualization:
+            output = output.astype(np.uint8)
+            style_name = "style_" + str(time.time()) + ".png"
+            if not os.path.exists(save_path):
+                os.mkdir(save_path)
+            path = os.path.join(save_path, style_name)
+            cv2.imwrite(path, output)
+        return output
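validation_step above combines a content loss with Gram-matrix style losses over VGG-style feature maps. A minimal, self-contained sketch of that loss recipe on random stand-in features (the real getFeature() outputs are not reproduced here; shapes and channel counts are arbitrary):

import numpy as np
import paddle
import paddle.nn as nn

paddle.disable_static()

def gram(feat):
    # Same Gram computation as gram_matrix() added later in this diff.
    b, ch, h, w = feat.shape
    flat = feat.reshape((b, ch, h * w))
    return flat.bmm(flat.transpose((0, 2, 1))) / (ch * h * w)

def rand_feats():
    return [paddle.to_tensor(np.random.rand(1, c, 16, 16).astype('float32')) for c in (8, 16)]

mse_loss = nn.MSELoss()
features_y = rand_feats()       # features of the stylised output
features_xc = rand_feats()      # features of the content image
features_style = rand_feats()   # features of the style image

content_loss = mse_loss(features_y[1], features_xc[1])
style_loss = 0.
for fy, fs in zip(features_y, features_style):
    style_loss += mse_loss(gram(fy), gram(fs))
loss = content_loss + style_loss
print(float(loss.numpy()))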
@@ -15,6 +15,7 @@
 import os

 import cv2
+import paddle
 import numpy as np
 from PIL import Image, ImageEnhance
@@ -114,7 +115,25 @@ def get_img_file(dir_name: str) -> list:
             if not is_image_file(filename):
                 continue
             img_path = os.path.join(parent, filename)
+            print(img_path)
             images.append(img_path)
     images.sort()
     return images
+
+
+def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
+    """Subtract ImageNet mean pixel-wise from a BGR image."""
+    mean = np.zeros(shape=batch.shape, dtype='float32')
+    mean[:, 0, :, :] = 103.939
+    mean[:, 1, :, :] = 116.779
+    mean[:, 2, :, :] = 123.680
+    mean = paddle.to_tensor(mean)
+    return batch - mean
+
+
+def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
+    """Get gram matrix"""
+    b, ch, h, w = data.shape
+    features = data.reshape((b, ch, w * h))
+    features_t = features.transpose((0, 2, 1))
+    gram = features.bmm(features_t) / (ch * h * w)
+    return gram
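A quick sanity check of subtract_imagenet_mean_batch as defined above; the input values are arbitrary.

import numpy as np
import paddle

paddle.disable_static()
# The BGR channel means (103.939, 116.779, 123.680) are subtracted from every pixel.
batch = paddle.to_tensor(np.full((1, 3, 2, 2), 255., dtype='float32'))
mean = np.zeros(batch.shape, dtype='float32')
mean[:, 0, :, :] = 103.939
mean[:, 1, :, :] = 116.779
mean[:, 2, :, :] = 123.680
out = batch - paddle.to_tensor(mean)
print(out[0, :, 0, 0].numpy())  # roughly [151.061, 138.221, 131.32]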
@@ -24,7 +24,7 @@ from paddlehub.process.functional import *
 class Compose:
-    def __init__(self, transforms, to_rgb=True, stay_rgb=False):
+    def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
         if not isinstance(transforms, list):
             raise TypeError('The transforms must be a list!')
         if len(transforms) < 1:
@@ -33,6 +33,7 @@ class Compose:
         self.transforms = transforms
         self.to_rgb = to_rgb
         self.stay_rgb = stay_rgb
+        self.is_permute = is_permute

     def __call__(self, im):
         if isinstance(im, str):
@@ -47,13 +48,14 @@ class Compose:
             im = op(im)

         if not self.stay_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+        if self.is_permute:
             im = permute(im)

         return im


 class RandomHorizontalFlip:
     def __init__(self, prob=0.5):
         self.prob = prob
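The new is_permute flag decouples the HWC-to-CHW permute from the stay_rgb colour-space switch. A tiny sketch of what the two flags control in __call__, assuming permute() is the usual HWC-to-CHW transpose:

import cv2
import numpy as np

im = np.random.rand(176, 176, 3).astype('float32')
stay_rgb, is_permute = True, False        # the values used by the colorization demo
if not stay_rgb:
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
if is_permute:
    im = np.transpose(im, (2, 0, 1))      # HWC -> CHW
print(im.shape)  # (176, 176, 3): untouched with these flag values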
@@ -239,8 +241,13 @@ class RandomPaddingCrop:
         pad_height = max(crop_height - img_height, 0)
         pad_width = max(crop_width - img_width, 0)
         if (pad_height > 0 or pad_width > 0):
-            im = cv2.copyMakeBorder(
-                im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.im_padding_value)
+            im = cv2.copyMakeBorder(im,
+                                    0,
+                                    pad_height,
+                                    0,
+                                    pad_width,
+                                    cv2.BORDER_CONSTANT,
+                                    value=self.im_padding_value)

         if crop_height > 0 and crop_width > 0:
             h_off = np.random.randint(img_height - crop_height + 1)
@@ -295,8 +302,7 @@ class RandomRotation:
             r[0, 2] += (nw / 2) - cx
             r[1, 2] += (nh / 2) - cy
             dsize = (nw, nh)
-            im = cv2.warpAffine(
-                im,
+            im = cv2.warpAffine(im,
                 r,
                 dsize=dsize,
                 flags=cv2.INTER_LINEAR,
@@ -429,7 +435,7 @@ class ConvertColorSpace:
         """
         mask = (rgb > 0.04045)
         np.seterr(invalid='ignore')
-        rgb = (((rgb + .055) / 1.055) ** 2.4) * mask + rgb / 12.92 * (1 - mask)
+        rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
         rgb = np.nan_to_num(rgb)
         x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
         y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
@@ -490,7 +496,7 @@ class ConvertColorSpace:
         rgb = np.maximum(rgb, 0)  # sometimes reaches a small negative number, which causes NaNs
         mask = (rgb > .0031308).astype(np.float32)
         np.seterr(invalid='ignore')
-        out = (1.055 * (rgb ** (1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
+        out = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
         out = np.nan_to_num(out)
         return out
@@ -511,7 +517,7 @@ class ConvertColorSpace:
         out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1)
         mask = (out > .2068966).astype(np.float32)
         np.seterr(invalid='ignore')
-        out = (out ** 3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
+        out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
         out = np.nan_to_num(out)
         sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
         out = out * sc
@@ -566,7 +572,7 @@ class ColorizeHint:
         self.use_avg = use_avg

     def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
-        sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9, ]
+        sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9]
         self.data = data
         self.hint = hint
         self.mask = mask
@@ -577,7 +583,7 @@ class ColorizeHint:
             while cont_cond:
                 if self.num_points is None:  # draw from geometric
                     # embed()
-                    cont_cond = np.random.rand() < (1 - self.percent)
+                    cont_cond = np.random.rand() > (1 - self.percent)
                 else:  # add certain number of points
                     cont_cond = pp < self.num_points
                 if not cont_cond:  # skip out of loop if condition not met
# add color point # add color point
if self.use_avg: if self.use_avg:
# embed() # embed()
hint[nn, :, h:h + P, w:w + P] = np.mean( hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), axis=1, keepdims=True).reshape( axis=2,
1, C, 1, 1) keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else: else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1 mask[nn, :, h:h + P, w:w + P] = 1
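A simplified numpy sketch of one hint-placement step in the use_avg branch above; the (N, C, H, W) data layout and the (N, 1, H, W) mask shape are assumptions for illustration.

import numpy as np

C, H, W, P = 2, 8, 8, 3
data = np.random.rand(1, C, H, W).astype('float32')
hint = np.zeros_like(data)
mask = np.zeros((1, 1, H, W), dtype='float32')
h = w = 2                                  # top-left corner of the sampled patch
avg = data[0, :, h:h + P, w:w + P].mean(axis=(1, 2)).reshape(C, 1, 1)
hint[0, :, h:h + P, w:w + P] = avg         # broadcast the patch-average colour
mask[0, :, h:h + P, w:w + P] = 1           # mark the hinted region
print(hint[0, :, h, w], mask[0, 0, h, w])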
@@ -641,8 +649,9 @@ class ColorizePreprocess:
         data(dict):The preprocessed data for colorization.
     """

-    def __init__(self, ab_thresh: float = 0.,
-                 p: float = .125,
+    def __init__(self,
+                 ab_thresh: float = 0.,
+                 p: float = 0.,
                  num_points: int = None,
                  samp: str = 'normal',
                  use_avg: bool = True,
""" """
data = {} data = {}
A = 2 * 110 / 10 + 1 A = 2 * 110 / 10 + 1
data['A'] = data_lab[:, [0, ], :, :] data['A'] = data_lab[:, [
0,
], :, :]
data['B'] = data_lab[:, 1:, :, :] data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110 thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),axis=1) mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = (mask >= thresh) mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :] data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :] data['B'] = data['B'][mask, :, :, :]
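A short sketch of the ab_thresh gate above, which drops near-grayscale images whose a/b channels barely vary; the shapes and threshold value here are arbitrary.

import numpy as np

ab_thresh = 10.
data_lab = np.random.rand(4, 3, 8, 8).astype('float32')   # fake NCHW LAB batch
B = data_lab[:, 1:, :, :]                                  # a/b channels
ab_range = np.max(np.max(B, axis=3), axis=2) - np.min(np.min(B, axis=3), axis=2)
keep = np.sum(np.abs(ab_range), axis=1) >= (1. * ab_thresh / 110)
print(keep)            # one boolean per image
print(B[keep].shape)   # only the kept images remain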
@@ -713,3 +725,41 @@ class ColorPostprocess:
         img = np.clip(img, 0, 1) * 255
         img = img.astype(self.type)
         return img
+
+
+class CenterCrop:
+    """
+    Crop the middle part of the image to the specified size.
+
+    Args:
+        crop_size(int): Crop size.
+
+    Return:
+        img(np.ndarray): Cropped image.
+    """
+    def __init__(self, crop_size: int):
+        self.crop_size = crop_size
+
+    def __call__(self, img: np.ndarray):
+        img_width, img_height, chanel = img.shape
+        crop_top = int((img_height - self.crop_size) / 2.)
+        crop_left = int((img_width - self.crop_size) / 2.)
+        return img[crop_left:crop_left + self.crop_size, crop_top:crop_top + self.crop_size, :]
+
+
+class SetType:
+    """
+    Set image type.
+
+    Args:
+        type(type): Type of Image value.
+
+    Return:
+        img(np.ndarray): Transformed image.
+    """
+    def __init__(self, datatype: type = 'float32'):
+        self.type = datatype
+
+    def __call__(self, img: np.ndarray):
+        img = img.astype(self.type)
+        return img
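A quick standalone check of the CenterCrop logic added above, mirroring its width/height handling on a dummy image (the class itself lives inside paddlehub.process.transforms).

import numpy as np

img = np.random.rand(256, 256, 3).astype('float32')
crop_size = 176
crop_top = int((img.shape[1] - crop_size) / 2.)
crop_left = int((img.shape[0] - crop_size) / 2.)
out = img[crop_left:crop_left + crop_size, crop_top:crop_top + crop_size, :]
print(out.shape)  # (176, 176, 3)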