diff --git a/configs/lapstyle_draft.yaml b/configs/lapstyle_draft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4864aacd7f03208fb62e363f376865c40c8ce2f --- /dev/null +++ b/configs/lapstyle_draft.yaml @@ -0,0 +1,67 @@ +total_iters: 30000 +output_dir: output_dir +checkpoints_dir: checkpoints +min_max: + (0., 1.) + +model: + name: LapStyleModel + generator_encode: + name: Encoder + generator_decode: + name: DecoderNet + calc_style_emd_loss: + name: CalcStyleEmdLoss + calc_content_relt_loss: + name: CalcContentReltLoss + calc_content_loss: + name: CalcContentLoss + calc_style_loss: + name: CalcStyleLoss + content_layers: ['r11', 'r21', 'r31', 'r41', 'r51'] + style_layers: ['r11', 'r21', 'r31', 'r41', 'r51'] + content_weight: 1.0 + style_weight: 3.0 + + +dataset: + train: + name: LapStyleDataset + content_root: data/coco/train2017/ + style_root: data/starrynew.png + load_size: 136 + crop_size: 128 + num_workers: 16 + batch_size: 5 + test: + name: LapStyleDataset + content_root: data/coco/test2017/ + style_root: data/starrynew.png + load_size: 136 + crop_size: 128 + num_workers: 0 + batch_size: 1 + +lr_scheduler: + name: NonLinearDecay + learning_rate: 1e-4 + lr_decay: 5e-5 + +optimizer: + optimG: + name: Adam + net_names: + - net_dec + beta1: 0.9 + beta2: 0.999 + +validate: + interval: 5000 + save_img: false + +log_config: + interval: 10 + visiual_interval: 5000 + +snapshot_config: + interval: 5000 diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index ce3d0cc143ef200e8464444f71d81dbbf56b06d1..843fd74ee6892d505c487470861f7ee79cf267dc 100755 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -23,3 +23,4 @@ from .wav2lip_dataset import Wav2LipDataset from .starganv2_dataset import StarGANv2Dataset from .edvr_dataset import REDSDataset from .firstorder_dataset import FirstOrderDataset +from .lapstyle_dataset import LapStyleDataset diff --git a/ppgan/datasets/lapstyle_dataset.py b/ppgan/datasets/lapstyle_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe92dc4c7756630b822c008a95a353f5888249b --- /dev/null +++ b/ppgan/datasets/lapstyle_dataset.py @@ -0,0 +1,90 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import numpy as np +from PIL import Image +import paddle +import paddle.vision.transforms as T +from paddle.io import Dataset + +from .builder import DATASETS + +logger = logging.getLogger(__name__) + + +def data_transform(crop_size): + transform_list = [T.RandomCrop(crop_size)] + return T.Compose(transform_list) + + +@DATASETS.register() +class LapStyleDataset(Dataset): + """ + coco2017 dataset for LapStyle model + """ + def __init__(self, content_root, style_root, load_size, crop_size): + super(LapStyleDataset, self).__init__() + self.content_root = content_root + self.paths = os.listdir(self.content_root) + self.style_root = style_root + self.load_size = load_size + self.crop_size = crop_size + self.transform = data_transform(self.crop_size) + + def __getitem__(self, index): + """Get training sample + + return: + ci: content image with shape [C,W,H], + si: style image with shape [C,W,H], + ci_path: str + """ + path = self.paths[index] + content_img = Image.open(os.path.join(self.content_root, + path)).convert('RGB') + content_img = content_img.resize((self.load_size, self.load_size), + Image.BILINEAR) + content_img = np.array(content_img) + style_img = Image.open(self.style_root).convert('RGB') + style_img = style_img.resize((self.load_size, self.load_size), + Image.BILINEAR) + style_img = np.array(style_img) + content_img = self.transform(content_img) + style_img = self.transform(style_img) + content_img = self.img(content_img) + style_img = self.img(style_img) + return {'ci': content_img, 'si': style_img, 'ci_path': path} + + def img(self, img): + """make image with [0,255] and HWC to [0,1] and CHW + + return: + img: image with shape [3,W,H] and value [0, 1]. + """ + # [0,255] to [0,1] + img = img.astype(np.float32) / 255. + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + # HWC to CHW + img = np.transpose(img, (2, 0, 1)).astype('float32') + return img + + def __len__(self): + return len(self.paths) + + def name(self): + return 'LapStyleDataset' diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index b49b8fc05c152afc3c5aa06112e622b548975047..e7cf0cbc103601ac4e7ed4624ca98d602b31e7c7 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -347,7 +347,10 @@ class Trainer: dataformats="HWC" if image_num == 1 else "NCHW") else: if self.cfg.is_train: - msg = 'epoch%.3d_' % self.current_epoch + if self.by_epoch: + msg = 'epoch%.3d_' % self.current_epoch + else: + msg = 'iter%.3d_' % self.current_iter else: msg = '' makedirs(os.path.join(self.output_dir, results_dir)) diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 43e727c9a904fc9cefaee8d5aad0ac4f08fda207..9c84a1eb1836a6530e2063a44580a667c013b4b8 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -29,3 +29,4 @@ from .wav2lip_hq_model import Wav2LipModelHq from .starganv2_model import StarGANv2Model from .edvr_model import EDVRModel from .firstorder_model import FirstOrderModel +from .lapstyle_model import LapStyleModel diff --git a/ppgan/models/criterions/__init__.py b/ppgan/models/criterions/__init__.py index a172f37c3cacea4339aabee7b9aaf1fc40b57083..cd760e71ebd002b73146845f9279fe584a2c5d4e 100644 --- a/ppgan/models/criterions/__init__.py +++ b/ppgan/models/criterions/__init__.py @@ -1,5 +1,7 @@ from .gan_loss import GANLoss from .perceptual_loss import PerceptualLoss -from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss +from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \ + CalcStyleEmdLoss, CalcContentReltLoss, \ + CalcContentLoss, CalcStyleLoss from .builder import build_criterion diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py index f496ef7cf90da5049ab180a9e0a12d50c8acf82f..11df70f3a38988d5bab8c62703f25c7a5fbdb21e 100644 --- a/ppgan/models/criterions/pixel_loss.py +++ b/ppgan/models/criterions/pixel_loss.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +from ..generators.generater_lapstyle import calc_mean_std, mean_variance_norm import paddle import paddle.nn as nn @@ -127,3 +128,109 @@ class BCEWithLogitsLoss(): weights. Default: None. """ return self.loss_weight * self._bce_loss(pred, target) + + +def calc_emd_loss(pred, target): + """calculate emd loss. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + """ + b, _, h, w = pred.shape + pred = pred.reshape([b, -1, w * h]) + pred_norm = paddle.sqrt((pred**2).sum(1).reshape([b, -1, 1])) + pred = pred.transpose([0, 2, 1]) + target_t = target.reshape([b, -1, w * h]) + target_norm = paddle.sqrt((target**2).sum(1).reshape([b, 1, -1])) + similarity = paddle.bmm(pred, target_t) / pred_norm / target_norm + dist = 1. - similarity + return dist + + +@CRITERIONS.register() +class CalcStyleEmdLoss(): + """Calc Style Emd Loss. + """ + def __init__(self): + super(CalcStyleEmdLoss, self).__init__() + + def __call__(self, pred, target): + """Forward Function. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + """ + CX_M = calc_emd_loss(pred, target) + m1 = CX_M.min(2) + m2 = CX_M.min(1) + m = paddle.concat([m1.mean(), m2.mean()]) + loss_remd = paddle.max(m) + return loss_remd + + +@CRITERIONS.register() +class CalcContentReltLoss(): + """Calc Content Relt Loss. + """ + def __init__(self): + super(CalcContentReltLoss, self).__init__() + + def __call__(self, pred, target): + """Forward Function. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + """ + dM = 1. + Mx = calc_emd_loss(pred, pred) + Mx = Mx / Mx.sum(1, keepdim=True) + My = calc_emd_loss(target, target) + My = My / My.sum(1, keepdim=True) + loss_content = paddle.abs( + dM * (Mx - My)).mean() * pred.shape[2] * pred.shape[3] + return loss_content + + +@CRITERIONS.register() +class CalcContentLoss(): + """Calc Content Loss. + """ + def __init__(self): + self.mse_loss = nn.MSELoss() + + def __call__(self, pred, target, norm=False): + """Forward Function. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + norm(Bool): whether use mean_variance_norm for pred and target + """ + if (norm == False): + return self.mse_loss(pred, target) + else: + return self.mse_loss(mean_variance_norm(pred), + mean_variance_norm(target)) + + +@CRITERIONS.register() +class CalcStyleLoss(): + """Calc Style Loss. + """ + def __init__(self): + self.mse_loss = nn.MSELoss() + + def __call__(self, pred, target): + """Forward Function. + + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + """ + pred_mean, pred_std = calc_mean_std(pred) + target_mean, target_std = calc_mean_std(target) + return self.mse_loss(pred_mean, target_mean) + self.mse_loss( + pred_std, target_std) diff --git a/ppgan/models/discriminators/__init__.py b/ppgan/models/discriminators/__init__.py index 0f9d8c9bc6c39362bded26d914f5a48f21964a5b..588241fa89793a942487af8acf6dcdbfa2c7461e 100644 --- a/ppgan/models/discriminators/__init__.py +++ b/ppgan/models/discriminators/__init__.py @@ -22,3 +22,4 @@ from .syncnet import SyncNetColor from .wav2lip_disc_qual import Wav2LipDiscQual from .discriminator_starganv2 import StarGANv2Discriminator from .discriminator_firstorder import FirstOrderDiscriminator +from .discriminator_lapstyle import LapStyleDiscriminator diff --git a/ppgan/models/discriminators/discriminator_lapstyle.py b/ppgan/models/discriminators/discriminator_lapstyle.py new file mode 100644 index 0000000000000000000000000000000000000000..624cfd8695df6f9510b8e0332c6bade0413a4fd5 --- /dev/null +++ b/ppgan/models/discriminators/discriminator_lapstyle.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .builder import DISCRIMINATORS + + +@DISCRIMINATORS.register() +class LapStyleDiscriminator(nn.Layer): + def __init__(self): + super(LapStyleDiscriminator, self).__init__() + num_layer = 3 + num_channel = 32 + self.head = nn.Sequential( + ('conv', + nn.Conv2D(3, num_channel, kernel_size=3, stride=1, padding=1)), + ('norm', nn.BatchNorm2D(num_channel)), + ('LeakyRelu', nn.LeakyReLU(0.2))) + self.body = nn.Sequential() + for i in range(num_layer - 2): + self.body.add_sublayer( + 'conv%d' % (i + 1), + nn.Conv2D(num_channel, + num_channel, + kernel_size=3, + stride=1, + padding=1)) + self.body.add_sublayer('norm%d' % (i + 1), + nn.BatchNorm2D(num_channel)) + self.body.add_sublayer('LeakyRelu%d' % (i + 1), nn.LeakyReLU(0.2)) + self.tail = nn.Conv2D(num_channel, + 1, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = self.head(x) + x = self.body(x) + x = self.tail(x) + return x diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index cb46b8ac8a7d0315d6b1287d08562dfd9f230392..0ad7aca559311b31b6693888deffc98ec8bd1073 100755 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -29,3 +29,4 @@ from .drn import DRNGenerator from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN from .edvr import EDVRNet from .generator_firstorder import FirstOrderGenerator +from .generater_lapstyle import DecoderNet, Encoder diff --git a/ppgan/models/generators/generater_lapstyle.py b/ppgan/models/generators/generater_lapstyle.py new file mode 100644 index 0000000000000000000000000000000000000000..53c0911527687d66ada00ffb2e61566bc1e8a260 --- /dev/null +++ b/ppgan/models/generators/generater_lapstyle.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn +from ...utils.download import get_path_from_url + +from .builder import GENERATORS + + +def calc_mean_std(feat, eps=1e-5): + """calculate mean and standard deviation. + + Args: + feat (Tensor): Tensor with shape (N, C, H, W). + eps (float): Default: 1e-5. + + Return: + mean and std of feat + shape: [N, C, 1, 1] + """ + size = feat.shape + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.reshape([N, C, -1]) + feat_var = paddle.var(feat_var, axis=2) + eps + feat_std = paddle.sqrt(feat_var) + feat_std = feat_std.reshape([N, C, 1, 1]) + feat_mean = feat.reshape([N, C, -1]) + feat_mean = paddle.mean(feat_mean, axis=2) + feat_mean = feat_mean.reshape([N, C, 1, 1]) + return feat_mean, feat_std + + +def mean_variance_norm(feat): + """mean_variance_norm. + + Args: + feat (Tensor): Tensor with shape (N, C, H, W). + + Return: + Normalized feat with shape (N, C, H, W) + """ + size = feat.shape + mean, std = calc_mean_std(feat) + normalized_feat = (feat - mean.expand(size)) / std.expand(size) + return normalized_feat + + +def adaptive_instance_normalization(content_feat, style_feat): + """adaptive_instance_normalization. + + Args: + content_feat (Tensor): Tensor with shape (N, C, H, W). + style_feat (Tensor): Tensor with shape (N, C, H, W). + + Return: + Normalized content_feat with shape (N, C, H, W) + """ + assert (content_feat.shape[:2] == style_feat.shape[:2]) + size = content_feat.shape + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + + normalized_feat = (content_feat - + content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class ResnetBlock(nn.Layer): + """Residual block. + + It has a style of: + ---Pad-Conv-ReLU-Pad-Conv-+- + |________________________| + + Args: + dim (int): Channel number of intermediate features. + """ + def __init__(self, dim): + super(ResnetBlock, self).__init__() + self.conv_block = nn.Sequential(nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(dim, dim, (3, 3)), nn.ReLU(), + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(dim, dim, (3, 3))) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class ConvBlock(nn.Layer): + """convolution block. + + It has a style of: + ---Pad-Conv-ReLU--- + + Args: + dim1 (int): Channel number of input features. + dim2 (int): Channel number of output features. + """ + def __init__(self, dim1, dim2): + super(ConvBlock, self).__init__() + self.conv_block = nn.Sequential(nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(dim1, dim2, (3, 3)), + nn.ReLU()) + + def forward(self, x): + out = self.conv_block(x) + return out + + +@GENERATORS.register() +class DecoderNet(nn.Layer): + """Decoder of Drafting module. + Paper: + Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality + Artistic Style Transfer. + """ + def __init__(self): + super(DecoderNet, self).__init__() + + self.resblock_41 = ResnetBlock(512) + self.convblock_41 = ConvBlock(512, 256) + self.resblock_31 = ResnetBlock(256) + self.convblock_31 = ConvBlock(256, 128) + + self.convblock_21 = ConvBlock(128, 128) + self.convblock_22 = ConvBlock(128, 64) + + self.convblock_11 = ConvBlock(64, 64) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + self.final_conv = nn.Sequential(nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(64, 3, (3, 3))) + + def forward(self, cF, sF): + + out = adaptive_instance_normalization(cF['r41'], sF['r41']) + out = self.resblock_41(out) + out = self.convblock_41(out) + + out = self.upsample(out) + out += adaptive_instance_normalization(cF['r31'], sF['r31']) + out = self.resblock_31(out) + out = self.convblock_31(out) + + out = self.upsample(out) + out += adaptive_instance_normalization(cF['r21'], sF['r21']) + out = self.convblock_21(out) + out = self.convblock_22(out) + + out = self.upsample(out) + out = self.convblock_11(out) + out = self.final_conv(out) + return out + + +vgg = nn.Sequential( + nn.Conv2D(3, 3, (1, 1)), + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(3, 64, (3, 3)), + nn.ReLU(), # relu1-1 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(64, 64, (3, 3)), + nn.ReLU(), # relu1-2 + nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(64, 128, (3, 3)), + nn.ReLU(), # relu2-1 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(128, 128, (3, 3)), + nn.ReLU(), # relu2-2 + nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(128, 256, (3, 3)), + nn.ReLU(), # relu3-1 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(256, 256, (3, 3)), + nn.ReLU(), # relu3-2 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(256, 256, (3, 3)), + nn.ReLU(), # relu3-3 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(256, 256, (3, 3)), + nn.ReLU(), # relu3-4 + nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(256, 512, (3, 3)), + nn.ReLU(), # relu4-1, this is the last layer used + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU(), # relu4-2 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU(), # relu4-3 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU(), # relu4-4 + nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU(), # relu5-1 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU(), # relu5-2 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU(), # relu5-3 + nn.Pad2D([1, 1, 1, 1], mode='reflect'), + nn.Conv2D(512, 512, (3, 3)), + nn.ReLU() # relu5-4 +) + + +@GENERATORS.register() +class Encoder(nn.Layer): + """Encoder of Drafting module. + Paper: + Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality + Artistic Style Transfer. + """ + def __init__(self): + super(Encoder, self).__init__() + vgg_net = vgg + weight_path = get_path_from_url( + 'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams') + vgg_net.set_dict(paddle.load(weight_path)) + self.enc_1 = nn.Sequential(*list( + vgg_net.children())[:4]) # input -> relu1_1 + self.enc_2 = nn.Sequential(*list( + vgg_net.children())[4:11]) # relu1_1 -> relu2_1 + self.enc_3 = nn.Sequential(*list( + vgg_net.children())[11:18]) # relu2_1 -> relu3_1 + self.enc_4 = nn.Sequential(*list( + vgg_net.children())[18:31]) # relu3_1 -> relu4_1 + self.enc_5 = nn.Sequential(*list( + vgg_net.children())[31:44]) # relu4_1 -> relu5_1 + + def forward(self, x): + out = {} + x = self.enc_1(x) + out['r11'] = x + x = self.enc_2(x) + out['r21'] = x + x = self.enc_3(x) + out['r31'] = x + x = self.enc_4(x) + out['r41'] = x + x = self.enc_5(x) + out['r51'] = x + return out diff --git a/ppgan/models/lapstyle_model.py b/ppgan/models/lapstyle_model.py new file mode 100644 index 0000000000000000000000000000000000000000..667bce6ae989d50168347c7c1c29f57fb70d0401 --- /dev/null +++ b/ppgan/models/lapstyle_model.py @@ -0,0 +1,118 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from .base_model import BaseModel + +from .builder import MODELS +from .generators.builder import build_generator +from .criterions import build_criterion + +from ..modules.init import init_weights + + +@MODELS.register() +class LapStyleModel(BaseModel): + def __init__(self, + generator_encode, + generator_decode, + calc_style_emd_loss=None, + calc_content_relt_loss=None, + calc_content_loss=None, + calc_style_loss=None, + content_layers=['r11', 'r21', 'r31', 'r41', 'r51'], + style_layers=['r11', 'r21', 'r31', 'r41', 'r51'], + content_weight=1.0, + style_weight=3.0): + + super(LapStyleModel, self).__init__() + + # define generators + self.nets['net_enc'] = build_generator(generator_encode) + self.nets['net_dec'] = build_generator(generator_decode) + init_weights(self.nets['net_dec']) + self.set_requires_grad([self.nets['net_enc']], False) + + # define loss functions + self.calc_style_emd_loss = build_criterion(calc_style_emd_loss) + self.calc_content_relt_loss = build_criterion(calc_content_relt_loss) + self.calc_content_loss = build_criterion(calc_content_loss) + self.calc_style_loss = build_criterion(calc_style_loss) + + self.content_layers = content_layers + self.style_layers = style_layers + self.content_weight = content_weight + self.style_weight = style_weight + + def setup_input(self, input): + self.ci = paddle.to_tensor(input['ci']) + self.visual_items['ci'] = self.ci + self.si = paddle.to_tensor(input['si']) + self.visual_items['si'] = self.si + self.image_paths = input['ci_path'] + + def forward(self): + """Run forward pass; called by both functions and .""" + self.cF = self.nets['net_enc'](self.ci) + self.sF = self.nets['net_enc'](self.si) + self.stylized = self.nets['net_dec'](self.cF, self.sF) + self.visual_items['stylized'] = self.stylized + + def backward_dnc(self): + self.tF = self.nets['net_enc'](self.stylized) + """content loss""" + self.loss_c = 0 + for layer in self.content_layers: + self.loss_c += self.calc_content_loss(self.tF[layer], + self.cF[layer], + norm=True) + self.losses['loss_c'] = self.loss_c + """style loss""" + self.loss_s = 0 + for layer in self.style_layers: + self.loss_s += self.calc_style_loss(self.tF[layer], self.sF[layer]) + self.losses['loss_s'] = self.loss_s + """IDENTITY LOSSES""" + self.Icc = self.nets['net_dec'](self.cF, self.cF) + self.l_identity1 = self.calc_content_loss(self.Icc, self.ci) + self.Fcc = self.nets['net_enc'](self.Icc) + self.l_identity2 = 0 + for layer in self.content_layers: + self.l_identity2 += self.calc_content_loss(self.Fcc[layer], + self.cF[layer]) + self.losses['l_identity1'] = self.l_identity1 + self.losses['l_identity2'] = self.l_identity2 + """relative loss""" + self.loss_style_remd = self.calc_style_emd_loss( + self.tF['r31'], self.sF['r31']) + self.calc_style_emd_loss( + self.tF['r41'], self.sF['r41']) + self.loss_content_relt = self.calc_content_relt_loss( + self.tF['r31'], self.cF['r31']) + self.calc_content_relt_loss( + self.tF['r41'], self.cF['r41']) + self.losses['loss_style_remd'] = self.loss_style_remd + self.losses['loss_content_relt'] = self.loss_content_relt + + self.loss = self.loss_c * self.content_weight + self.loss_s * self.style_weight +\ + self.l_identity1 * 50 + self.l_identity2 * 1 + self.loss_style_remd * 10 + \ + self.loss_content_relt * 16 + self.loss.backward() + + return self.loss + + def train_iter(self, optimizers=None): + """Calculate losses, gradients, and update network weights""" + self.forward() + optimizers['optimG'].clear_grad() + self.backward_dnc() + self.optimizers['optimG'].step() diff --git a/ppgan/solver/__init__.py b/ppgan/solver/__init__.py index 1b4d1fc7b586773978d80c0a397592b0ca7af5de..41df0560513a1bf2ec33b743df7c69aed61a2d85 100644 --- a/ppgan/solver/__init__.py +++ b/ppgan/solver/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay +from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay, NonLinearDecay from .optimizer import * from .builder import build_lr_scheduler from .builder import build_optimizer diff --git a/ppgan/solver/lr_scheduler.py b/ppgan/solver/lr_scheduler.py index aa7cc3de1ddeb95434eafd26b1c2bb8c7c8b8e0b..2f6fa3dfcb8691068f5a9a2e011e983a4d53ff4d 100644 --- a/ppgan/solver/lr_scheduler.py +++ b/ppgan/solver/lr_scheduler.py @@ -21,6 +21,17 @@ from .builder import LRSCHEDULERS LRSCHEDULERS.register(MultiStepDecay) +@LRSCHEDULERS.register() +class NonLinearDecay(LRScheduler): + def __init__(self, learning_rate, lr_decay, last_epoch=-1): + self.lr_decay = lr_decay + super(NonLinearDecay, self).__init__(learning_rate, last_epoch) + + def get_lr(self): + lr = self.base_lr / (1.0 + self.lr_decay * self.last_epoch) + return lr + + @LRSCHEDULERS.register() class LinearDecay(LambdaDecay): def __init__(self, learning_rate, start_epoch, decay_epochs,