未验证 提交 0ded43f7 编写于 作者: W wangna11BD 提交者: GitHub

Add LapStyle Model from vis (#307)

* add LapStyle Model
上级 c12e50aa
total_iters: 30000
output_dir: output_dir
checkpoints_dir: checkpoints
min_max:
(0., 1.)
model:
name: LapStyleModel
generator_encode:
name: Encoder
generator_decode:
name: DecoderNet
calc_style_emd_loss:
name: CalcStyleEmdLoss
calc_content_relt_loss:
name: CalcContentReltLoss
calc_content_loss:
name: CalcContentLoss
calc_style_loss:
name: CalcStyleLoss
content_layers: ['r11', 'r21', 'r31', 'r41', 'r51']
style_layers: ['r11', 'r21', 'r31', 'r41', 'r51']
content_weight: 1.0
style_weight: 3.0
dataset:
train:
name: LapStyleDataset
content_root: data/coco/train2017/
style_root: data/starrynew.png
load_size: 136
crop_size: 128
num_workers: 16
batch_size: 5
test:
name: LapStyleDataset
content_root: data/coco/test2017/
style_root: data/starrynew.png
load_size: 136
crop_size: 128
num_workers: 0
batch_size: 1
lr_scheduler:
name: NonLinearDecay
learning_rate: 1e-4
lr_decay: 5e-5
optimizer:
optimG:
name: Adam
net_names:
- net_dec
beta1: 0.9
beta2: 0.999
validate:
interval: 5000
save_img: false
log_config:
interval: 10
visiual_interval: 5000
snapshot_config:
interval: 5000
......@@ -23,3 +23,4 @@ from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
from .lapstyle_dataset import LapStyleDataset
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import numpy as np
from PIL import Image
import paddle
import paddle.vision.transforms as T
from paddle.io import Dataset
from .builder import DATASETS
logger = logging.getLogger(__name__)
def data_transform(crop_size):
transform_list = [T.RandomCrop(crop_size)]
return T.Compose(transform_list)
@DATASETS.register()
class LapStyleDataset(Dataset):
"""
coco2017 dataset for LapStyle model
"""
def __init__(self, content_root, style_root, load_size, crop_size):
super(LapStyleDataset, self).__init__()
self.content_root = content_root
self.paths = os.listdir(self.content_root)
self.style_root = style_root
self.load_size = load_size
self.crop_size = crop_size
self.transform = data_transform(self.crop_size)
def __getitem__(self, index):
"""Get training sample
return:
ci: content image with shape [C,W,H],
si: style image with shape [C,W,H],
ci_path: str
"""
path = self.paths[index]
content_img = Image.open(os.path.join(self.content_root,
path)).convert('RGB')
content_img = content_img.resize((self.load_size, self.load_size),
Image.BILINEAR)
content_img = np.array(content_img)
style_img = Image.open(self.style_root).convert('RGB')
style_img = style_img.resize((self.load_size, self.load_size),
Image.BILINEAR)
style_img = np.array(style_img)
content_img = self.transform(content_img)
style_img = self.transform(style_img)
content_img = self.img(content_img)
style_img = self.img(style_img)
return {'ci': content_img, 'si': style_img, 'ci_path': path}
def img(self, img):
"""make image with [0,255] and HWC to [0,1] and CHW
return:
img: image with shape [3,W,H] and value [0, 1].
"""
# [0,255] to [0,1]
img = img.astype(np.float32) / 255.
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
# HWC to CHW
img = np.transpose(img, (2, 0, 1)).astype('float32')
return img
def __len__(self):
return len(self.paths)
def name(self):
return 'LapStyleDataset'
......@@ -347,7 +347,10 @@ class Trainer:
dataformats="HWC" if image_num == 1 else "NCHW")
else:
if self.cfg.is_train:
msg = 'epoch%.3d_' % self.current_epoch
if self.by_epoch:
msg = 'epoch%.3d_' % self.current_epoch
else:
msg = 'iter%.3d_' % self.current_iter
else:
msg = ''
makedirs(os.path.join(self.output_dir, results_dir))
......
......@@ -29,3 +29,4 @@ from .wav2lip_hq_model import Wav2LipModelHq
from .starganv2_model import StarGANv2Model
from .edvr_model import EDVRModel
from .firstorder_model import FirstOrderModel
from .lapstyle_model import LapStyleModel
from .gan_loss import GANLoss
from .perceptual_loss import PerceptualLoss
from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss
from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \
CalcStyleEmdLoss, CalcContentReltLoss, \
CalcContentLoss, CalcStyleLoss
from .builder import build_criterion
......@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
from ..generators.generater_lapstyle import calc_mean_std, mean_variance_norm
import paddle
import paddle.nn as nn
......@@ -127,3 +128,109 @@ class BCEWithLogitsLoss():
weights. Default: None.
"""
return self.loss_weight * self._bce_loss(pred, target)
def calc_emd_loss(pred, target):
"""calculate emd loss.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
b, _, h, w = pred.shape
pred = pred.reshape([b, -1, w * h])
pred_norm = paddle.sqrt((pred**2).sum(1).reshape([b, -1, 1]))
pred = pred.transpose([0, 2, 1])
target_t = target.reshape([b, -1, w * h])
target_norm = paddle.sqrt((target**2).sum(1).reshape([b, 1, -1]))
similarity = paddle.bmm(pred, target_t) / pred_norm / target_norm
dist = 1. - similarity
return dist
@CRITERIONS.register()
class CalcStyleEmdLoss():
"""Calc Style Emd Loss.
"""
def __init__(self):
super(CalcStyleEmdLoss, self).__init__()
def __call__(self, pred, target):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
CX_M = calc_emd_loss(pred, target)
m1 = CX_M.min(2)
m2 = CX_M.min(1)
m = paddle.concat([m1.mean(), m2.mean()])
loss_remd = paddle.max(m)
return loss_remd
@CRITERIONS.register()
class CalcContentReltLoss():
"""Calc Content Relt Loss.
"""
def __init__(self):
super(CalcContentReltLoss, self).__init__()
def __call__(self, pred, target):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
dM = 1.
Mx = calc_emd_loss(pred, pred)
Mx = Mx / Mx.sum(1, keepdim=True)
My = calc_emd_loss(target, target)
My = My / My.sum(1, keepdim=True)
loss_content = paddle.abs(
dM * (Mx - My)).mean() * pred.shape[2] * pred.shape[3]
return loss_content
@CRITERIONS.register()
class CalcContentLoss():
"""Calc Content Loss.
"""
def __init__(self):
self.mse_loss = nn.MSELoss()
def __call__(self, pred, target, norm=False):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
norm(Bool): whether use mean_variance_norm for pred and target
"""
if (norm == False):
return self.mse_loss(pred, target)
else:
return self.mse_loss(mean_variance_norm(pred),
mean_variance_norm(target))
@CRITERIONS.register()
class CalcStyleLoss():
"""Calc Style Loss.
"""
def __init__(self):
self.mse_loss = nn.MSELoss()
def __call__(self, pred, target):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
pred_mean, pred_std = calc_mean_std(pred)
target_mean, target_std = calc_mean_std(target)
return self.mse_loss(pred_mean, target_mean) + self.mse_loss(
pred_std, target_std)
......@@ -22,3 +22,4 @@ from .syncnet import SyncNetColor
from .wav2lip_disc_qual import Wav2LipDiscQual
from .discriminator_starganv2 import StarGANv2Discriminator
from .discriminator_firstorder import FirstOrderDiscriminator
from .discriminator_lapstyle import LapStyleDiscriminator
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class LapStyleDiscriminator(nn.Layer):
def __init__(self):
super(LapStyleDiscriminator, self).__init__()
num_layer = 3
num_channel = 32
self.head = nn.Sequential(
('conv',
nn.Conv2D(3, num_channel, kernel_size=3, stride=1, padding=1)),
('norm', nn.BatchNorm2D(num_channel)),
('LeakyRelu', nn.LeakyReLU(0.2)))
self.body = nn.Sequential()
for i in range(num_layer - 2):
self.body.add_sublayer(
'conv%d' % (i + 1),
nn.Conv2D(num_channel,
num_channel,
kernel_size=3,
stride=1,
padding=1))
self.body.add_sublayer('norm%d' % (i + 1),
nn.BatchNorm2D(num_channel))
self.body.add_sublayer('LeakyRelu%d' % (i + 1), nn.LeakyReLU(0.2))
self.tail = nn.Conv2D(num_channel,
1,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = self.head(x)
x = self.body(x)
x = self.tail(x)
return x
......@@ -29,3 +29,4 @@ from .drn import DRNGenerator
from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN
from .edvr import EDVRNet
from .generator_firstorder import FirstOrderGenerator
from .generater_lapstyle import DecoderNet, Encoder
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
from ...utils.download import get_path_from_url
from .builder import GENERATORS
def calc_mean_std(feat, eps=1e-5):
"""calculate mean and standard deviation.
Args:
feat (Tensor): Tensor with shape (N, C, H, W).
eps (float): Default: 1e-5.
Return:
mean and std of feat
shape: [N, C, 1, 1]
"""
size = feat.shape
assert (len(size) == 4)
N, C = size[:2]
feat_var = feat.reshape([N, C, -1])
feat_var = paddle.var(feat_var, axis=2) + eps
feat_std = paddle.sqrt(feat_var)
feat_std = feat_std.reshape([N, C, 1, 1])
feat_mean = feat.reshape([N, C, -1])
feat_mean = paddle.mean(feat_mean, axis=2)
feat_mean = feat_mean.reshape([N, C, 1, 1])
return feat_mean, feat_std
def mean_variance_norm(feat):
"""mean_variance_norm.
Args:
feat (Tensor): Tensor with shape (N, C, H, W).
Return:
Normalized feat with shape (N, C, H, W)
"""
size = feat.shape
mean, std = calc_mean_std(feat)
normalized_feat = (feat - mean.expand(size)) / std.expand(size)
return normalized_feat
def adaptive_instance_normalization(content_feat, style_feat):
"""adaptive_instance_normalization.
Args:
content_feat (Tensor): Tensor with shape (N, C, H, W).
style_feat (Tensor): Tensor with shape (N, C, H, W).
Return:
Normalized content_feat with shape (N, C, H, W)
"""
assert (content_feat.shape[:2] == style_feat.shape[:2])
size = content_feat.shape
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat -
content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class ResnetBlock(nn.Layer):
"""Residual block.
It has a style of:
---Pad-Conv-ReLU-Pad-Conv-+-
|________________________|
Args:
dim (int): Channel number of intermediate features.
"""
def __init__(self, dim):
super(ResnetBlock, self).__init__()
self.conv_block = nn.Sequential(nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(dim, dim, (3, 3)), nn.ReLU(),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(dim, dim, (3, 3)))
def forward(self, x):
out = x + self.conv_block(x)
return out
class ConvBlock(nn.Layer):
"""convolution block.
It has a style of:
---Pad-Conv-ReLU---
Args:
dim1 (int): Channel number of input features.
dim2 (int): Channel number of output features.
"""
def __init__(self, dim1, dim2):
super(ConvBlock, self).__init__()
self.conv_block = nn.Sequential(nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(dim1, dim2, (3, 3)),
nn.ReLU())
def forward(self, x):
out = self.conv_block(x)
return out
@GENERATORS.register()
class DecoderNet(nn.Layer):
"""Decoder of Drafting module.
Paper:
Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
Artistic Style Transfer.
"""
def __init__(self):
super(DecoderNet, self).__init__()
self.resblock_41 = ResnetBlock(512)
self.convblock_41 = ConvBlock(512, 256)
self.resblock_31 = ResnetBlock(256)
self.convblock_31 = ConvBlock(256, 128)
self.convblock_21 = ConvBlock(128, 128)
self.convblock_22 = ConvBlock(128, 64)
self.convblock_11 = ConvBlock(64, 64)
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.final_conv = nn.Sequential(nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(64, 3, (3, 3)))
def forward(self, cF, sF):
out = adaptive_instance_normalization(cF['r41'], sF['r41'])
out = self.resblock_41(out)
out = self.convblock_41(out)
out = self.upsample(out)
out += adaptive_instance_normalization(cF['r31'], sF['r31'])
out = self.resblock_31(out)
out = self.convblock_31(out)
out = self.upsample(out)
out += adaptive_instance_normalization(cF['r21'], sF['r21'])
out = self.convblock_21(out)
out = self.convblock_22(out)
out = self.upsample(out)
out = self.convblock_11(out)
out = self.final_conv(out)
return out
vgg = nn.Sequential(
nn.Conv2D(3, 3, (1, 1)),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(3, 64, (3, 3)),
nn.ReLU(), # relu1-1
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(64, 64, (3, 3)),
nn.ReLU(), # relu1-2
nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(64, 128, (3, 3)),
nn.ReLU(), # relu2-1
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(128, 128, (3, 3)),
nn.ReLU(), # relu2-2
nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(128, 256, (3, 3)),
nn.ReLU(), # relu3-1
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(256, 256, (3, 3)),
nn.ReLU(), # relu3-2
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(256, 256, (3, 3)),
nn.ReLU(), # relu3-3
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(256, 256, (3, 3)),
nn.ReLU(), # relu3-4
nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(256, 512, (3, 3)),
nn.ReLU(), # relu4-1, this is the last layer used
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU(), # relu4-2
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU(), # relu4-3
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU(), # relu4-4
nn.MaxPool2D((2, 2), (2, 2), (0, 0), ceil_mode=True),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU(), # relu5-1
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU(), # relu5-2
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU(), # relu5-3
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU() # relu5-4
)
@GENERATORS.register()
class Encoder(nn.Layer):
"""Encoder of Drafting module.
Paper:
Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
Artistic Style Transfer.
"""
def __init__(self):
super(Encoder, self).__init__()
vgg_net = vgg
weight_path = get_path_from_url(
'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams')
vgg_net.set_dict(paddle.load(weight_path))
self.enc_1 = nn.Sequential(*list(
vgg_net.children())[:4]) # input -> relu1_1
self.enc_2 = nn.Sequential(*list(
vgg_net.children())[4:11]) # relu1_1 -> relu2_1
self.enc_3 = nn.Sequential(*list(
vgg_net.children())[11:18]) # relu2_1 -> relu3_1
self.enc_4 = nn.Sequential(*list(
vgg_net.children())[18:31]) # relu3_1 -> relu4_1
self.enc_5 = nn.Sequential(*list(
vgg_net.children())[31:44]) # relu4_1 -> relu5_1
def forward(self, x):
out = {}
x = self.enc_1(x)
out['r11'] = x
x = self.enc_2(x)
out['r21'] = x
x = self.enc_3(x)
out['r31'] = x
x = self.enc_4(x)
out['r41'] = x
x = self.enc_5(x)
out['r51'] = x
return out
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .criterions import build_criterion
from ..modules.init import init_weights
@MODELS.register()
class LapStyleModel(BaseModel):
def __init__(self,
generator_encode,
generator_decode,
calc_style_emd_loss=None,
calc_content_relt_loss=None,
calc_content_loss=None,
calc_style_loss=None,
content_layers=['r11', 'r21', 'r31', 'r41', 'r51'],
style_layers=['r11', 'r21', 'r31', 'r41', 'r51'],
content_weight=1.0,
style_weight=3.0):
super(LapStyleModel, self).__init__()
# define generators
self.nets['net_enc'] = build_generator(generator_encode)
self.nets['net_dec'] = build_generator(generator_decode)
init_weights(self.nets['net_dec'])
self.set_requires_grad([self.nets['net_enc']], False)
# define loss functions
self.calc_style_emd_loss = build_criterion(calc_style_emd_loss)
self.calc_content_relt_loss = build_criterion(calc_content_relt_loss)
self.calc_content_loss = build_criterion(calc_content_loss)
self.calc_style_loss = build_criterion(calc_style_loss)
self.content_layers = content_layers
self.style_layers = style_layers
self.content_weight = content_weight
self.style_weight = style_weight
def setup_input(self, input):
self.ci = paddle.to_tensor(input['ci'])
self.visual_items['ci'] = self.ci
self.si = paddle.to_tensor(input['si'])
self.visual_items['si'] = self.si
self.image_paths = input['ci_path']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.cF = self.nets['net_enc'](self.ci)
self.sF = self.nets['net_enc'](self.si)
self.stylized = self.nets['net_dec'](self.cF, self.sF)
self.visual_items['stylized'] = self.stylized
def backward_dnc(self):
self.tF = self.nets['net_enc'](self.stylized)
"""content loss"""
self.loss_c = 0
for layer in self.content_layers:
self.loss_c += self.calc_content_loss(self.tF[layer],
self.cF[layer],
norm=True)
self.losses['loss_c'] = self.loss_c
"""style loss"""
self.loss_s = 0
for layer in self.style_layers:
self.loss_s += self.calc_style_loss(self.tF[layer], self.sF[layer])
self.losses['loss_s'] = self.loss_s
"""IDENTITY LOSSES"""
self.Icc = self.nets['net_dec'](self.cF, self.cF)
self.l_identity1 = self.calc_content_loss(self.Icc, self.ci)
self.Fcc = self.nets['net_enc'](self.Icc)
self.l_identity2 = 0
for layer in self.content_layers:
self.l_identity2 += self.calc_content_loss(self.Fcc[layer],
self.cF[layer])
self.losses['l_identity1'] = self.l_identity1
self.losses['l_identity2'] = self.l_identity2
"""relative loss"""
self.loss_style_remd = self.calc_style_emd_loss(
self.tF['r31'], self.sF['r31']) + self.calc_style_emd_loss(
self.tF['r41'], self.sF['r41'])
self.loss_content_relt = self.calc_content_relt_loss(
self.tF['r31'], self.cF['r31']) + self.calc_content_relt_loss(
self.tF['r41'], self.cF['r41'])
self.losses['loss_style_remd'] = self.loss_style_remd
self.losses['loss_content_relt'] = self.loss_content_relt
self.loss = self.loss_c * self.content_weight + self.loss_s * self.style_weight +\
self.l_identity1 * 50 + self.l_identity2 * 1 + self.loss_style_remd * 10 + \
self.loss_content_relt * 16
self.loss.backward()
return self.loss
def train_iter(self, optimizers=None):
"""Calculate losses, gradients, and update network weights"""
self.forward()
optimizers['optimG'].clear_grad()
self.backward_dnc()
self.optimizers['optimG'].step()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay
from .lr_scheduler import CosineAnnealingRestartLR, LinearDecay, NonLinearDecay
from .optimizer import *
from .builder import build_lr_scheduler
from .builder import build_optimizer
......@@ -21,6 +21,17 @@ from .builder import LRSCHEDULERS
LRSCHEDULERS.register(MultiStepDecay)
@LRSCHEDULERS.register()
class NonLinearDecay(LRScheduler):
def __init__(self, learning_rate, lr_decay, last_epoch=-1):
self.lr_decay = lr_decay
super(NonLinearDecay, self).__init__(learning_rate, last_epoch)
def get_lr(self):
lr = self.base_lr / (1.0 + self.lr_decay * self.last_epoch)
return lr
@LRSCHEDULERS.register()
class LinearDecay(LambdaDecay):
def __init__(self, learning_rate, start_epoch, decay_epochs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册