diff --git a/demo/style_transfer/candy.jpg b/demo/style_transfer/candy.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f40e5a33e93329581baaa2ba564e02ce25615cbe
Binary files /dev/null and b/demo/style_transfer/candy.jpg differ
diff --git a/demo/style_transfer/predict.py b/demo/style_transfer/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c80bfa524417b82aee6399ed579baed5bb3801d
--- /dev/null
+++ b/demo/style_transfer/predict.py
@@ -0,0 +1,9 @@
+import paddle
+import paddlehub as hub
+
+if __name__ == '__main__':
+    place = paddle.CUDAPlace(0)
+    paddle.disable_static(place)
+    model = hub.Module(name='msgnet')
+    model.eval()
+    result = model.predict("venice-boat.jpg", "candy.jpg")
diff --git a/demo/style_transfer/train.py b/demo/style_transfer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd4183bd7d3b85f9e706fc3f96b6a54946f55e20
--- /dev/null
+++ b/demo/style_transfer/train.py
@@ -0,0 +1,17 @@
+import paddle
+import paddlehub as hub
+
+from paddlehub.finetune.trainer import Trainer
+from paddlehub.datasets.styletransfer import StyleTransferData
+from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType
+
+if __name__ == "__main__":
+    place = paddle.CUDAPlace(0)
+    paddle.disable_static(place)
+    model = hub.Module(name='msgnet')
+    transform = Compose([Resize((256, 256), interp='LINEAR'), CenterCrop(crop_size=256)], SetType(datatype='float32'))
+    styledata = StyleTransferData(transform)
+    model.train()
+    optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
+    trainer = Trainer(model, optimizer, checkpoint_dir='test_style_ckpt')
+    trainer.train(styledata, epochs=5, batch_size=1, eval_dataset=styledata, log_interval=1, save_interval=1)
diff --git a/demo/style_transfer/venice-boat.jpg b/demo/style_transfer/venice-boat.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cfa6217a41f18d0a335d241992f6ca9991483da7
Binary files /dev/null and b/demo/style_transfer/venice-boat.jpg differ
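A note on the demo scripts above: `Trainer` writes per-epoch parameters under `checkpoint_dir`, and `MSGNet.__init__` (below) accepts a `load_checkpoint` path, so a fine-tuned model can be restored for prediction. A minimal sketch, assuming `hub.Module` forwards keyword arguments to the module class; the checkpoint path is hypothetical and should be adjusted to whatever `Trainer` actually saved on disk:

```python
import paddle
import paddlehub as hub

if __name__ == '__main__':
    paddle.disable_static()
    # 'test_style_ckpt/epoch_5/model.pdparams' is a hypothetical path; point it
    # at the parameters Trainer saved under checkpoint_dir in train.py.
    model = hub.Module(name='msgnet', load_checkpoint='test_style_ckpt/epoch_5/model.pdparams')
    model.eval()
    result = model.predict("venice-boat.jpg", "candy.jpg")
```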
diff --git a/hub_module/modules/image/style_transfer/msgnet/module.py b/hub_module/modules/image/style_transfer/msgnet/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..590b432f141bb168d85f6987003abbd5f59c668a
--- /dev/null
+++ b/hub_module/modules/image/style_transfer/msgnet/module.py
@@ -0,0 +1,353 @@
+import os
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+import paddle.nn.functional as F
+
+from paddlehub.module.module import moduleinfo
+from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType
+from paddlehub.module.cv_module import StyleTransferModule
+
+
+class GramMatrix(nn.Layer):
+    """Calculate gram matrix"""
+    def forward(self, y):
+        (b, ch, h, w) = y.shape  # .shape, not .size(): paddle tensors have no callable size
+        features = y.reshape((b, ch, w * h))
+        features_t = features.transpose((0, 2, 1))
+        gram = features.bmm(features_t) / (ch * h * w)
+        return gram
+
+
+class ConvLayer(nn.Layer):
+    """Basic conv layer with reflection padding layer"""
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int):
+        super(ConvLayer, self).__init__()
+        pad = int(np.floor(kernel_size / 2))
+        self.reflection_pad = nn.ReflectionPad2d([pad, pad, pad, pad])
+        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+    def forward(self, x: paddle.Tensor):
+        out = self.reflection_pad(x)
+        out = self.conv2d(out)
+        return out
+
+
+class UpsampleConvLayer(nn.Layer):
+    """
+    Upsamples the input and then does a convolution.
+    This method gives better results compared to ConvTranspose2d.
+    ref: http://distill.pub/2016/deconv-checkerboard/
+
+    Args:
+        in_channels(int): Number of input channels.
+        out_channels(int): Number of output channels.
+        kernel_size(int): Size of the convolution kernel.
+        stride(int): Stride of the convolution.
+        upsample(int): Scale factor for the upsample layer, default is None.
+
+    Return:
+        img(paddle.Tensor): UpsampleConvLayer output.
+    """
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, upsample=None):
+        super(UpsampleConvLayer, self).__init__()
+        self.upsample = upsample
+        if upsample:
+            self.upsample_layer = nn.UpSample(scale_factor=upsample)
+        self.pad = int(np.floor(kernel_size / 2))
+        if self.pad != 0:
+            self.reflection_pad = nn.ReflectionPad2d([self.pad, self.pad, self.pad, self.pad])
+        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+    def forward(self, x):
+        if self.upsample:
+            x = self.upsample_layer(x)
+        if self.pad != 0:
+            x = self.reflection_pad(x)
+        out = self.conv2d(x)
+        return out
+
+
+class Bottleneck(nn.Layer):
+    """ Pre-activation residual block
+    Identity Mapping in Deep Residual Networks
+    ref https://arxiv.org/abs/1603.05027
+
+    Args:
+        inplanes(int): Number of input channels.
+        planes(int): Number of output channels.
+        stride(int): Stride of the convolution, default is 1.
+        downsample(int): If not None, a 1x1 projection is built for the residual branch, default is None.
+        norm_layer(nn.Layer): Batch norm layer, default is nn.BatchNorm2d.
+
+    Return:
+        img(paddle.Tensor): Bottleneck output.
+    """
+    def __init__(self,
+                 inplanes: int,
+                 planes: int,
+                 stride: int = 1,
+                 downsample: int = None,
+                 norm_layer: nn.Layer = nn.BatchNorm2d):
+        super(Bottleneck, self).__init__()
+        self.expansion = 4
+        self.downsample = downsample
+        if self.downsample is not None:
+            self.residual_layer = nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=stride)
+
+        conv_block = (norm_layer(inplanes), nn.ReLU(), nn.Conv2d(inplanes, planes, kernel_size=1, stride=1),
+                      norm_layer(planes), nn.ReLU(), ConvLayer(planes, planes, kernel_size=3, stride=stride),
+                      norm_layer(planes), nn.ReLU(), nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+                                                               stride=1))
+
+        self.conv_block = nn.Sequential(*conv_block)
+
+    def forward(self, x: paddle.Tensor):
+        if self.downsample is not None:
+            residual = self.residual_layer(x)
+        else:
+            residual = x
+        out = self.conv_block(x)  # evaluate the block once; the original ran it twice
+        return residual + out
+ """ + def __init__(self, inplanes: int, planes: int, stride: int = 2, norm_layer: nn.Layer = nn.BatchNorm2d): + super(UpBottleneck, self).__init__() + self.expansion = 4 + self.residual_layer = UpsampleConvLayer(inplanes, + planes * self.expansion, + kernel_size=1, + stride=1, + upsample=stride) + conv_block = [] + conv_block += [norm_layer(inplanes), nn.ReLU(), nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)] + + conv_block += [ + norm_layer(planes), + nn.ReLU(), + UpsampleConvLayer(planes, planes, kernel_size=3, stride=1, upsample=stride) + ] + + conv_block += [ + norm_layer(planes), + nn.ReLU(), + nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x: paddle.Tensor): + return self.residual_layer(x) + self.conv_block(x) + + +class Inspiration(nn.Layer): + """ Inspiration Layer (from MSG-Net paper) + tuning the featuremap with target Gram Matrix + ref https://arxiv.org/abs/1703.06953 + + Args: + C(int): Number of input channels. + B(int): B is equal to 1 or input mini_batch, default is 1. + + Return: + img(paddle.Tensor): UpBottleneck output. + """ + def __init__(self, C: int, B: int = 1): + super(Inspiration, self).__init__() + + self.weight = self.weight = paddle.create_parameter(shape=[1, C, C], dtype='float32') + # non-parameter buffer + self.G = paddle.to_tensor(np.random.rand(B, C, C)) + self.C = C + + def setTarget(self, target: paddle.Tensor): + self.G = target + + def forward(self, X: paddle.Tensor): + # input X is a 3D feature map + self.P = paddle.bmm(self.weight.expand_as(self.G), self.G) + + x = paddle.bmm( + self.P.transpose((0, 2, 1)).expand((X.shape[0], self.C, self.C)), X.reshape( + (X.shape[0], X.shape[1], -1))).reshape(X.shape) + return x + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + 'N x ' + str(self.C) + ')' + + +class Vgg16(nn.Layer): + """ First four layers from Vgg16.""" + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + checkpoint = os.path.join(self.directory, 'vgg16.pdparams') + if not os.path.exists(checkpoint): + os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/vgg_paddle.pdparams -O ' + + checkpoint) + model_dict = paddle.load(checkpoint)[0] + self.set_dict(model_dict) + print("load pretrained vgg16 checkpoint success") + + def forward(self, X): + h = F.relu(self.conv1_1(X)) + h = F.relu(self.conv1_2(h)) + relu1_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + relu2_2 = h + h = F.max_pool2d(h, kernel_size=2, 
+
+
+@moduleinfo(
+    name="msgnet",
+    type="CV/image_editing",
+    author="paddlepaddle",
+    author_email="",
+    summary="MSGNet is an image style transfer model; this module is trained on the COCO2014 dataset.",
+    version="1.0.0",
+    meta=StyleTransferModule)
+class MSGNet(nn.Layer):
+    """ MSGNet (from MSG-Net paper)
+    Enables passing identity all the way through the generator
+    ref https://arxiv.org/abs/1703.06953
+
+    Args:
+        input_nc(int): Number of input channels, default is 3.
+        output_nc(int): Number of output channels, default is 3.
+        ngf(int): Number of channels in the middle layers, default is 128.
+        n_blocks(int): Number of residual blocks, default is 6.
+        norm_layer(nn.Layer): Batch norm layer, default is nn.InstanceNorm2d.
+        load_checkpoint(str): Pretrained checkpoint path, default is None.
+
+    Return:
+        img(paddle.Tensor): MSGNet output.
+    """
+    def __init__(self,
+                 input_nc=3,
+                 output_nc=3,
+                 ngf=128,
+                 n_blocks=6,
+                 norm_layer=nn.InstanceNorm2d,
+                 load_checkpoint=None):
+        super(MSGNet, self).__init__()
+        self.gram = GramMatrix()
+        block = Bottleneck
+        upblock = UpBottleneck
+        expansion = 4
+
+        model1 = [
+            ConvLayer(input_nc, 64, kernel_size=7, stride=1),
+            norm_layer(64),
+            nn.ReLU(),
+            block(64, 32, 2, 1, norm_layer),
+            block(32 * expansion, ngf, 2, 1, norm_layer)
+        ]
+
+        self.model1 = nn.Sequential(*tuple(model1))
+
+        model = []
+        model += model1
+
+        self.ins = Inspiration(ngf * expansion)
+        model.append(self.ins)
+        for i in range(n_blocks):
+            model += [block(ngf * expansion, ngf, 1, None, norm_layer)]
+
+        model += [
+            upblock(ngf * expansion, 32, 2, norm_layer),
+            upblock(32 * expansion, 16, 2, norm_layer),
+            norm_layer(16 * expansion),
+            nn.ReLU(),
+            ConvLayer(16 * expansion, output_nc, kernel_size=7, stride=1)
+        ]
+        model = tuple(model)
+        self.model = nn.Sequential(*model)
+
+        if load_checkpoint is not None:
+            model_dict = paddle.load(load_checkpoint)[0]
+            self.set_dict(model_dict)
+            print("load custom checkpoint success")
+
+        else:
+            checkpoint = os.path.join(self.directory, 'style_paddle.pdparams')
+            if not os.path.exists(checkpoint):
+                os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O ' +
+                          checkpoint)
+            model_dict = paddle.load(checkpoint)[0]
+            model_dict_clone = model_dict.copy()
+            for key, value in model_dict_clone.items():
+                if key.endswith("scale"):
+                    name = key.rsplit('.', 1)[0] + '.bias'
+                    model_dict[name] = paddle.zeros(shape=model_dict[name].shape, dtype='float32')
+                    model_dict[key] = paddle.ones(shape=model_dict[key].shape, dtype='float32')
+            self.set_dict(model_dict)
+            print("load pretrained checkpoint success")
+
+        self._vgg = None
+
+    def transform(self, path: str):
+        transform = Compose([Resize(
+            (256, 256), interp='LINEAR'), CenterCrop(crop_size=256)], SetType(datatype='float32'))
+        return transform(path)
+
+    def setTarget(self, Xs: paddle.Tensor):
+        """Calculate feature gram matrix"""
+        f = self.model1(Xs)  # renamed from F to avoid shadowing paddle.nn.functional
+        G = self.gram(f)
+        self.ins.setTarget(G)
+
+    def getFeature(self, input: paddle.Tensor):
+        if not self._vgg:
+            self._vgg = Vgg16(self.directory)  # pass the module dir so Vgg16 can cache its weights
+        return self._vgg(input)
+
+    def forward(self, input: paddle.Tensor):
+        return self.model(input)
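For reviewers following the control flow: stylization is a two-step protocol. `setTarget` pushes the style image through the first half of the generator (`model1`) and caches its gram matrix inside the `Inspiration` layer; a subsequent plain forward pass on the content image then renders that style. `StyleTransferModule.predict` wraps equivalent steps; the sketch below restates them directly, assuming `transform` yields a CHW float32 numpy array (as `SetType(datatype='float32')` suggests):

```python
import numpy as np
import paddle
import paddlehub as hub

paddle.disable_static()
model = hub.Module(name='msgnet')
model.eval()

# transform is assumed to return CHW float32 arrays; add a batch axis for the net.
content = paddle.to_tensor(model.transform('venice-boat.jpg')[np.newaxis, ...])
style = paddle.to_tensor(model.transform('candy.jpg')[np.newaxis, ...])

model.setTarget(style)     # cache the style gram matrix in the Inspiration layer
stylized = model(content)  # forward pass on the content image now produces the stylized output
```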
diff --git a/paddlehub/datasets/styletransfer.py b/paddlehub/datasets/styletransfer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aef5a3b4d7392da8e6d9c8c1c6f903f037391e5d
--- /dev/null
+++ b/paddlehub/datasets/styletransfer.py
@@ -0,0 +1,58 @@
+# coding:utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable
+
+import paddle
+from paddlehub.process.functional import get_img_file
+from paddlehub.env import DATA_HOME
+
+
+class StyleTransferData(paddle.io.Dataset):
+    """
+    Dataset for style transfer.
+
+    Args:
+        transform(Callable): The method to preprocess images.
+        mode(str): The dataset mode, 'train' or 'test', default is 'train'.
+
+    Returns:
+        DataSet: An iterable object for data iteration.
+    """
+    def __init__(self, transform: Callable, mode: str = 'train'):
+        self.mode = mode
+        self.transform = transform
+
+        if self.mode == 'train':
+            self.file = 'train'
+        elif self.mode == 'test':
+            self.file = 'test'
+        self.file = os.path.join(DATA_HOME, 'minicoco', self.file)
+        self.style_file = os.path.join(DATA_HOME, 'minicoco', '21styles')
+        self.data = get_img_file(self.file)
+        self.style = get_img_file(self.style_file)
+
+    def __getitem__(self, idx: int):
+
+        img_path = self.data[idx]
+        im = self.transform(img_path)
+        style_idx = idx % len(self.style)  # cycle through the style images
+        style_path = self.style[style_idx]
+        style = self.transform(style_path)
+        return im, style
+
+    def __len__(self):
+        return len(self.data)
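The dataset does not download anything itself; it reads an existing layout under PaddleHub's `DATA_HOME` (the directory names below come straight from the class), pairing each content image with a style image cycled by index. A quick usage sketch:

```python
# Expected on-disk layout (names taken from StyleTransferData above):
#   $DATA_HOME/minicoco/train      content images for training
#   $DATA_HOME/minicoco/test       content images for evaluation
#   $DATA_HOME/minicoco/21styles   style images, cycled via idx % len(style)
from paddlehub.datasets.styletransfer import StyleTransferData
from paddlehub.process.transforms import Compose, Resize, CenterCrop, SetType

transform = Compose([Resize((256, 256), interp='LINEAR'), CenterCrop(crop_size=256)], SetType(datatype='float32'))
data = StyleTransferData(transform, mode='train')
content, style = data[0]  # one (content, style) pair, both preprocessed identically
```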
diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py
index d62b1622d4bb32ea5247f3f7602b87eac0d8095f..ac462368c37085de94e7964e885a34c33463767d 100644
--- a/paddlehub/module/cv_module.py
+++ b/paddlehub/module/cv_module.py
@@ -103,11 +103,11 @@ class ImageColorizeModule(RunModule, ImageServing):
     def training_step(self, batch: int, batch_idx: int) -> dict:
         '''
         One step for training, which should be called as forward computation.
-        
+
         Args:
             batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
             batch_idx(int): The index of batch.
-        
+
         Returns:
             results(dict) : The model outputs, such as loss and metrics.
         '''
@@ -116,46 +116,48 @@ class ImageColorizeModule(RunModule, ImageServing):
     def validation_step(self, batch: int, batch_idx: int) -> dict:
         '''
         One step for validation, which should be called as forward computation.
-        
+
         Args:
             batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
             batch_idx(int): The index of batch.
-        
+
         Returns:
             results(dict) : The model outputs, such as metrics.
         '''
         out_class, out_reg = self(batch[0], batch[1], batch[2])
-        
+
         criterionCE = nn.loss.CrossEntropyLoss()
         loss_ce = criterionCE(out_class, batch[4][:, 0, :, :])
         loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True)
         loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
         loss = loss_ce + loss_G_L1_reg
-        
+
         visual_ret = OrderedDict()
         psnrs = []
         lab2rgb = ConvertColorSpace(mode='LAB2RGB')
         process = ColorPostprocess()
+
         for i in range(batch[0].numpy().shape[0]):
             real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
             visual_ret['real'] = process(real)
             fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i]
             visual_ret['fake_reg'] = process(fake)
-            mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0) ** 2)
+            mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
             psnr_value = 20 * np.log10(255. / np.sqrt(mse))
             psnrs.append(psnr_value)
         psnr = paddle.to_variable(np.array(psnrs))
+
         return {'loss': loss, 'metrics': {'psnr': psnr}}

     def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
         '''
         Colorize images
-        
+
         Args:
             images(str) : Images path to be colorized.
             visualization(bool): Whether to save colorized images.
             save_path(str) : Path to save colorized images.
-        
+
         Returns:
             results(list[dict]) : The prediction result of each input image
         '''
@@ -177,7 +179,7 @@ class ImageColorizeModule(RunModule, ImageServing):
             visual_ret['real'] = resize(process(real))
             fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i]
             visual_ret['fake_reg'] = resize(process(fake))
-            
+
             if visualization:
                 fake_name = "fake_" + str(time.time()) + ".png"
                 if not os.path.exists(save_path):
@@ -185,8 +187,8 @@ class ImageColorizeModule(RunModule, ImageServing):
                 fake_path = os.path.join(save_path, fake_name)
                 visual_gray = Image.fromarray(visual_ret['fake_reg'])
                 visual_gray.save(fake_path)
-            
-            mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0) ** 2)
+
+            mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
             psnr_value = 20 * np.log10(255. / np.sqrt(mse))
             result.append(visual_ret)
             return result
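One clarifying note on the metric touched by the whitespace cleanup above: `validation_step` and `predict` compute PSNR from the LAB-to-RGB reconstructions using the standard 8-bit formula, PSNR = 20 * log10(255 / sqrt(MSE)). A self-contained restatement of that arithmetic:

```python
import numpy as np

def psnr(real: np.ndarray, fake: np.ndarray) -> float:
    """Same arithmetic as ImageColorizeModule: mean-squared error over
    0-255 images, then 20 * log10(MAX / sqrt(MSE)) with MAX = 255."""
    mse = np.mean((real * 1.0 - fake * 1.0)**2)
    return 20 * np.log10(255. / np.sqrt(mse))

# Identical images give infinite PSNR; a uniform absolute error of 1 gives ~48 dB.
```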