未验证 提交 99267da4 编写于 作者: L LielinJiang 提交者: GitHub

Add UGATIT model (#87)

* add ugatit model
上级 37eda2ca
...@@ -85,6 +85,13 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆) ...@@ -85,6 +85,13 @@ GAN--生成对抗网络,被“卷积网络之父”**Yann LeCun(杨立昆)
</div> </div>
### 人脸动漫化
<div align='center'>
<img src='./docs/imgs/ugatit.png' width='700' height='250'/>
</div>
## 版本更新 ## 版本更新
- v0.1.0 (2020.11.02) - v0.1.0 (2020.11.02)
......
...@@ -74,6 +74,12 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional ...@@ -74,6 +74,12 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
</div> </div>
### Face cartoonization
<div align='center'>
<img src='./docs/imgs/ugatit.png' width='700' height='250'/>
</div>
## Changelog ## Changelog
- v0.1.0 (2020.11.02) - v0.1.0 (2020.11.02)
......
epochs: 300
output_dir: output_dir
adv_weight: 1.0
cycle_weight: 10.0
identity_weight: 10.0
cam_weight: 1000.0
model:
name: UGATITModel
generator:
name: ResnetUGATITGenerator
input_nc: 3
output_nc: 3
ngf: 64
n_blocks: 4
img_size: 256
light: True
discriminator_g:
name: UGATITDiscriminator
input_nc: 3
ndf: 64
n_layers: 7
discriminator_l:
name: UGATITDiscriminator
input_nc: 3
ndf: 64
n_layers: 5
dataset:
train:
name: UnpairedDataset
dataroot: data/selfie2anime
num_workers: 0
phase: train
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
transforms:
- name: Resize
size: [286, 286]
interpolation: 'bilinear' #'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop
size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
test:
name: SingleDataset
dataroot: data/selfie2anime/testA
max_dataset_size: inf
direction: AtoB
input_nc: 3
output_nc: 3
serial_batches: False
transforms:
- name: Resize
size: [256, 256]
interpolation: 'bilinear' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
weight_decay: 0.0001
lr_scheduler:
name: linear
learning_rate: 0.0001
start_epoch: 150
decay_epochs: 150
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 30
...@@ -186,7 +186,6 @@ class EDVRPredictor(BasePredictor): ...@@ -186,7 +186,6 @@ class EDVRPredictor(BasePredictor):
period = cur_time - prev_time period = cur_time - prev_time
periods.append(period) periods.append(period)
# print('Processed {} samples'.format(infer_iter + 1))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, vid_out_path = os.path.join(self.output,
'{}_edvr_out.mp4'.format(base_name)) '{}_edvr_out.mp4'.format(base_name))
......
...@@ -17,8 +17,9 @@ import time ...@@ -17,8 +17,9 @@ import time
import copy import copy
import logging import logging
import paddle import datetime
import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
from ..datasets.builder import build_dataloader from ..datasets.builder import build_dataloader
...@@ -64,6 +65,9 @@ class Trainer: ...@@ -64,6 +65,9 @@ class Trainer:
self.local_rank = ParallelEnv().local_rank self.local_rank = ParallelEnv().local_rank
# time count # time count
self.steps_per_epoch = len(self.train_dataloader)
self.total_steps = self.epochs * self.steps_per_epoch
self.time_count = {} self.time_count = {}
self.best_metric = {} self.best_metric = {}
...@@ -219,7 +223,14 @@ class Trainer: ...@@ -219,7 +223,14 @@ class Trainer:
message += 'reader_cost: %.5f sec ' % self.data_time message += 'reader_cost: %.5f sec ' % self.data_time
if hasattr(self, 'ips'): if hasattr(self, 'ips'):
message += 'ips: %.5f images/s' % self.ips message += 'ips: %.5f images/s ' % self.ips
if hasattr(self, 'step_time'):
cur_step = self.steps_per_epoch * (self.current_epoch -
1) + self.batch_id
eta = self.step_time * (self.total_steps - cur_step - 1)
eta_str = str(datetime.timedelta(seconds=int(eta)))
message += f'eta: {eta_str}'
# print the message # print the message
self.logger.info(message) self.logger.info(message)
......
...@@ -18,4 +18,5 @@ from .pix2pix_model import Pix2PixModel ...@@ -18,4 +18,5 @@ from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel from .srgan_model import SRGANModel
from .sr_model import SRModel from .sr_model import SRModel
from .makeup_model import MakeupModel from .makeup_model import MakeupModel
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel from .dc_gan_model import DCGANModel
...@@ -13,4 +13,5 @@ ...@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .nlayers import NLayerDiscriminator from .nlayers import NLayerDiscriminator
from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator from .dcdiscriminator import DCDiscriminator
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...modules.utils import spectral_norm
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class UGATITDiscriminator(nn.Layer):
    """Spectral-normalised discriminator with a class-activation-map branch.

    The feature extractor is a stack of reflect-padded 4x4 convolutions
    followed by LeakyReLU(0.2). Two single-output linear heads score the
    globally average-pooled and max-pooled features; their weights are also
    used to re-weight the feature map channel-wise before the final
    convolution.

    Args:
        input_nc (int): number of channels of the input image.
        ndf (int): number of filters of the first convolution.
        n_layers (int): controls the number of stride-2 stages
            (``n_layers - 2`` of them) that precede the stride-1 stage.
    """
    def __init__(self, input_nc, ndf=64, n_layers=5):
        super(UGATITDiscriminator, self).__init__()

        def conv_stage(in_ch, out_ch, stride):
            # One stage: reflect pad -> spectral-normalised 4x4 conv -> LeakyReLU.
            return [
                nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect"),
                spectral_norm(
                    nn.Conv2D(in_ch,
                              out_ch,
                              kernel_size=4,
                              stride=stride,
                              padding=0,
                              bias_attr=True)),
                nn.LeakyReLU(0.2),
            ]

        # First stage plus (n_layers - 3) further stride-2 stages, each
        # doubling the channel count.
        stages = conv_stage(input_nc, ndf, 2)
        for level in range(1, n_layers - 2):
            width = ndf * 2**(level - 1)
            stages.extend(conv_stage(width, width * 2, 2))

        # Last stage of the extractor keeps the spatial stride at 1.
        width = ndf * 2**(n_layers - 2 - 1)
        stages.extend(conv_stage(width, width * 2, 1))

        # Class Activation Map
        cam_ch = ndf * 2**(n_layers - 2)
        self.gap_fc = spectral_norm(nn.Linear(cam_ch, 1, bias_attr=False))
        self.gmp_fc = spectral_norm(nn.Linear(cam_ch, 1, bias_attr=False))
        # Fuses the concatenated (avg-weighted, max-weighted) maps back to cam_ch.
        self.conv1x1 = nn.Conv2D(cam_ch * 2,
                                 cam_ch,
                                 kernel_size=1,
                                 stride=1,
                                 bias_attr=True)
        self.leaky_relu = nn.LeakyReLU(0.2)

        self.pad = nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect")
        self.conv = spectral_norm(
            nn.Conv2D(cam_ch,
                      1,
                      kernel_size=4,
                      stride=1,
                      padding=0,
                      bias_attr=False))

        self.model = nn.Sequential(*stages)

    def forward(self, input):
        """Score ``input`` and expose the attention by-products.

        Returns:
            tuple: ``(out, cam_logit, heatmap)`` where ``out`` is the output
            of the final 1-channel convolution, ``cam_logit`` is the
            concatenation (axis 1) of the average-pool and max-pool head
            logits, and ``heatmap`` is the channel-wise sum of the fused
            feature map.
        """
        feat = self.model(input)
        batch = feat.shape[0]

        # Average-pool branch: logit from pooled features, then re-weight the
        # feature map with the first parameter of the head.
        avg_pooled = F.adaptive_avg_pool2d(feat, 1)
        gap_logit = self.gap_fc(avg_pooled.reshape([batch, -1]))
        gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
        gap = feat * gap_weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same construction.
        max_pooled = F.adaptive_max_pool2d(feat, 1)
        gmp_logit = self.gmp_fc(max_pooled.reshape([batch, -1]))
        gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
        gmp = feat * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = paddle.concat([gap_logit, gmp_logit], 1)

        fused = paddle.concat([gap, gmp], 1)
        fused = self.leaky_relu(self.conv1x1(fused))

        heatmap = paddle.sum(fused, 1, keepdim=True)

        out = self.conv(self.pad(fused))

        return out, cam_logit, heatmap
...@@ -16,4 +16,5 @@ from .resnet import ResnetGenerator ...@@ -16,4 +16,5 @@ from .resnet import ResnetGenerator
from .unet import UnetGenerator from .unet import UnetGenerator
from .rrdb_net import RRDBNet from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention from .makeup import GeneratorPSGANAttention
from .resnet_ugatit import ResnetUGATITGenerator
from .dcgenerator import DCGenerator from .dcgenerator import DCGenerator
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ...modules.norm import build_norm_layer
from ...modules.utils import spectral_norm
from .builder import GENERATORS
@GENERATORS.register()
class ResnetUGATITGenerator(nn.Layer):
    """ResNet-style UGATIT generator with a CAM branch and AdaILN decoder.

    Layout: a convolutional encoder (stem, two stride-2 stages, ``n_blocks``
    residual blocks), a class-activation-map branch, a small fully-connected
    network producing ``gamma``/``beta`` for the AdaILN residual blocks, and
    an upsampling decoder ending in ``Tanh``.

    Args:
        input_nc (int): channels of the input image.
        output_nc (int): channels of the generated image.
        ngf (int): number of filters of the stem convolution.
        n_blocks (int): number of residual blocks in the encoder and of
            AdaILN residual blocks in the decoder; must be >= 0.
        img_size (int): input resolution; only used to size the first
            fully-connected layer when ``light`` is False.
        light (bool): if True, the gamma/beta network consumes globally
            average-pooled features instead of the flattened feature map.
        norm_type (str): passed to ``build_norm_layer`` to pick the encoder
            normalisation layer.
    """
    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=64,
                 n_blocks=6,
                 img_size=256,
                 light=False,
                 norm_type='instance'):
        assert (n_blocks >= 0)
        super(ResnetUGATITGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.n_blocks = n_blocks
        self.img_size = img_size
        self.light = light

        norm_layer = build_norm_layer(norm_type)

        # Stem: 7x7 convolution at full resolution.
        encoder = [
            nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
            nn.Conv2D(input_nc,
                      ngf,
                      kernel_size=7,
                      stride=1,
                      padding=0,
                      bias_attr=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]

        # Down-Sampling
        n_downsampling = 2
        for level in range(n_downsampling):
            in_ch = ngf * 2**level
            encoder.extend([
                nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect"),
                nn.Conv2D(in_ch,
                          in_ch * 2,
                          kernel_size=3,
                          stride=2,
                          padding=0,
                          bias_attr=False),
                norm_layer(in_ch * 2),
                nn.ReLU(),
            ])

        # Down-Sampling Bottleneck
        mult = 2**n_downsampling
        feat_ch = ngf * mult
        encoder.extend(
            ResnetBlock(feat_ch, use_bias=False, norm_layer=norm_layer)
            for _ in range(n_blocks))

        # Class Activation Map
        self.gap_fc = nn.Linear(feat_ch, 1, bias_attr=False)
        self.gmp_fc = nn.Linear(feat_ch, 1, bias_attr=False)
        self.conv1x1 = nn.Conv2D(feat_ch * 2,
                                 feat_ch,
                                 kernel_size=1,
                                 stride=1,
                                 bias_attr=True)
        self.relu = nn.ReLU()

        # Gamma, Beta block: only the width of the first linear layer depends
        # on the `light` switch.
        if self.light:
            fc_in = feat_ch
        else:
            fc_in = img_size // mult * img_size // mult * ngf * mult
        mlp = [
            nn.Linear(fc_in, feat_ch, bias_attr=False),
            nn.ReLU(),
            nn.Linear(feat_ch, feat_ch, bias_attr=False),
            nn.ReLU(),
        ]
        self.gamma = nn.Linear(feat_ch, feat_ch, bias_attr=False)
        self.beta = nn.Linear(feat_ch, feat_ch, bias_attr=False)

        # Up-Sampling Bottleneck; blocks are registered as UpBlock1_1..n
        # because forward() passes gamma/beta to each of them explicitly.
        for idx in range(1, n_blocks + 1):
            setattr(self, 'UpBlock1_' + str(idx),
                    ResnetAdaILNBlock(feat_ch, use_bias=False))

        # Up-Sampling
        decoder = []
        for level in range(n_downsampling):
            in_ch = ngf * 2**(n_downsampling - level)
            out_ch = int(in_ch / 2)
            decoder.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect"),
                nn.Conv2D(in_ch,
                          out_ch,
                          kernel_size=3,
                          stride=1,
                          padding=0,
                          bias_attr=False),
                ILN(out_ch),
                nn.ReLU(),
            ])

        # Output head: 7x7 convolution followed by Tanh.
        decoder.extend([
            nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
            nn.Conv2D(ngf,
                      output_nc,
                      kernel_size=7,
                      stride=1,
                      padding=0,
                      bias_attr=False),
            nn.Tanh(),
        ])

        self.DownBlock = nn.Sequential(*encoder)
        self.FC = nn.Sequential(*mlp)
        self.UpBlock2 = nn.Sequential(*decoder)

    def forward(self, input):
        """Translate ``input`` and expose the attention by-products.

        Returns:
            tuple: ``(out, cam_logit, heatmap)`` where ``out`` is the decoder
            output, ``cam_logit`` concatenates (axis 1) the average-pool and
            max-pool head logits, and ``heatmap`` is the channel-wise sum of
            the fused feature map.
        """
        feat = self.DownBlock(input)
        batch = feat.shape[0]

        # Average-pool branch: logit, then channel re-weighting by the head weight.
        avg_pooled = F.adaptive_avg_pool2d(feat, 1)
        gap_logit = self.gap_fc(avg_pooled.reshape([batch, -1]))
        gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
        gap = feat * gap_weight.unsqueeze(2).unsqueeze(3)

        # Max-pool branch, same construction.
        max_pooled = F.adaptive_max_pool2d(feat, 1)
        gmp_logit = self.gmp_fc(max_pooled.reshape([batch, -1]))
        gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
        gmp = feat * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = paddle.concat([gap_logit, gmp_logit], 1)

        fused = paddle.concat([gap, gmp], 1)
        fused = self.relu(self.conv1x1(fused))

        heatmap = paddle.sum(fused, axis=1, keepdim=True)

        # Features fed to the gamma/beta network: pooled in light mode,
        # the full flattened map otherwise.
        if self.light:
            pooled = F.adaptive_avg_pool2d(fused, 1)
            style = self.FC(pooled.reshape([pooled.shape[0], -1]))
        else:
            style = self.FC(fused.reshape([fused.shape[0], -1]))
        gamma, beta = self.gamma(style), self.beta(style)

        hidden = fused
        for idx in range(1, self.n_blocks + 1):
            hidden = getattr(self, 'UpBlock1_' + str(idx))(hidden, gamma, beta)
        out = self.UpBlock2(hidden)

        return out, cam_logit, heatmap
class ResnetBlock(nn.Layer):
    """Residual block: two reflect-padded 3x3 convolutions plus identity skip.

    Args:
        dim (int): number of input and output channels.
        use_bias: forwarded as ``bias_attr`` to both convolutions.
        norm_layer (callable): called with ``dim`` to build the
            normalisation layer that follows each convolution.
    """
    def __init__(self, dim, use_bias, norm_layer):
        super(ResnetBlock, self).__init__()
        # pad -> conv -> norm -> ReLU -> pad -> conv -> norm (no final ReLU).
        layers = [
            nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect"),
            nn.Conv2D(dim,
                      dim,
                      kernel_size=3,
                      stride=1,
                      padding=0,
                      bias_attr=use_bias),
            norm_layer(dim),
            nn.ReLU(),
            nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect"),
            nn.Conv2D(dim,
                      dim,
                      kernel_size=3,
                      stride=1,
                      padding=0,
                      bias_attr=use_bias),
            norm_layer(dim),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)
class ResnetAdaILNBlock(nn.Layer):
    """Residual block whose two normalisations are AdaILN layers.

    Unlike a plain residual block, ``forward`` takes externally supplied
    ``gamma`` and ``beta`` and hands them to both AdaILN layers.

    Args:
        dim (int): number of input and output channels.
        use_bias: forwarded as ``bias_attr`` to both convolutions.
    """
    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3,
                           stride=1,
                           padding=0,
                           bias_attr=use_bias)

        # First half: pad -> conv -> AdaILN -> ReLU.
        self.pad1 = nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect")
        self.conv1 = nn.Conv2D(dim, dim, **conv_kwargs)
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU()

        # Second half: pad -> conv -> AdaILN (no activation before the skip).
        self.pad2 = nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect")
        self.conv2 = nn.Conv2D(dim, dim, **conv_kwargs)
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        """Apply both conv/AdaILN halves and add the input back."""
        hidden = self.conv1(self.pad1(x))
        hidden = self.relu1(self.norm1(hidden, gamma, beta))
        hidden = self.conv2(self.pad2(hidden))
        hidden = self.norm2(hidden, gamma, beta)

        return hidden + x
class AdaILN(nn.Layer):
    """Adaptive instance/layer normalisation with external scale and shift.

    The output blends an instance-normalised and a layer-normalised copy of
    the input using the learnable per-channel ratio ``rho`` (initialised to
    0.9), then applies the ``gamma`` and ``beta`` given to ``forward``.

    Args:
        num_features (int): number of channels; ``rho`` has shape
            ``(1, num_features, 1, 1)``.
        eps (float): added to the variance before the square root.
    """
    def __init__(self, num_features, eps=1e-5):
        super(AdaILN, self).__init__()
        self.eps = eps
        rho_shape = (1, num_features, 1, 1)
        self.rho = self.create_parameter(rho_shape)
        self.rho.set_value(paddle.full(rho_shape, 0.9))

    def forward(self, input, gamma, beta):
        """Normalise ``input`` and modulate it with ``gamma`` / ``beta``.

        ``gamma`` and ``beta`` get two trailing singleton axes appended
        before being applied (assumes they are 2-D, batch x channels —
        TODO confirm against the caller).
        """
        # Instance statistics: over axes 2 and 3.
        in_mean = paddle.mean(input, [2, 3], keepdim=True)
        in_var = paddle.var(input, [2, 3], keepdim=True)
        out_in = (input - in_mean) / paddle.sqrt(in_var + self.eps)

        # Layer statistics: over axes 1, 2 and 3.
        ln_mean = paddle.mean(input, [1, 2, 3], keepdim=True)
        ln_var = paddle.var(input, [1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / paddle.sqrt(ln_var + self.eps)

        # Blend the two normalisations with the per-channel ratio.
        rho = self.rho.expand([input.shape[0], -1, -1, -1])
        blended = rho * out_in + (1 - rho) * out_ln

        scale = gamma.unsqueeze(2).unsqueeze(3)
        shift = beta.unsqueeze(2).unsqueeze(3)
        return blended * scale + shift
class ILN(nn.Layer):
    """Instance/layer normalisation blend with its own scale and shift.

    Learnable per-channel parameters, each of shape
    ``(1, num_features, 1, 1)``: ``rho`` (blend ratio, initialised to 0.0),
    ``gamma`` (scale, initialised to 1.0) and ``beta`` (shift, initialised
    to 0.0).

    Args:
        num_features (int): number of channels.
        eps (float): added to the variance before the square root.
    """
    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        # Create all three parameters first, then overwrite their values.
        self.rho = self.create_parameter(param_shape)
        self.gamma = self.create_parameter(param_shape)
        self.beta = self.create_parameter(param_shape)
        for param, init_value in ((self.rho, 0.0), (self.gamma, 1.0),
                                  (self.beta, 0.0)):
            param.set_value(paddle.full(param_shape, init_value))

    def forward(self, input):
        """Return the blended normalisation of ``input``, scaled and shifted."""
        batch_shape = [input.shape[0], -1, -1, -1]

        # Instance statistics: over axes 2 and 3.
        in_mean = paddle.mean(input, [2, 3], keepdim=True)
        in_var = paddle.var(input, [2, 3], keepdim=True)
        out_in = (input - in_mean) / paddle.sqrt(in_var + self.eps)

        # Layer statistics: over axes 1, 2 and 3.
        ln_mean = paddle.mean(input, [1, 2, 3], keepdim=True)
        ln_var = paddle.var(input, [1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / paddle.sqrt(ln_var + self.eps)

        rho = self.rho.expand(batch_shape)
        blended = rho * out_in + (1 - rho) * out_ln

        return blended * self.gamma.expand(batch_shape) + self.beta.expand(
            batch_shape)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
from ..solver import build_optimizer
from ..modules.nn import RhoClipper
from ..modules.init import init_weights
from ..utils.image_pool import ImagePool
@MODELS.register()
class UGATITModel(BaseModel):
"""
This class implements the UGATIT model, for learning image-to-image translation without paired data.
UGATIT paper: https://arxiv.org/pdf/1907.10830.pdf
"""
def __init__(self, cfg):
"""Initialize the CycleGAN class.
Parameters:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
"""
super(UGATITModel, self).__init__(cfg)
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
self.nets['genA2B'] = build_generator(cfg.model.generator)
self.nets['genB2A'] = build_generator(cfg.model.generator)
init_weights(self.nets['genA2B'])
init_weights(self.nets['genB2A'])
if self.is_train:
# define discriminators
self.nets['disGA'] = build_discriminator(cfg.model.discriminator_g)
self.nets['disGB'] = build_discriminator(cfg.model.discriminator_g)
self.nets['disLA'] = build_discriminator(cfg.model.discriminator_l)
self.nets['disLB'] = build_discriminator(cfg.model.discriminator_l)
init_weights(self.nets['disGA'])
init_weights(self.nets['disGB'])
init_weights(self.nets['disLA'])
init_weights(self.nets['disLB'])
if self.is_train:
# define loss functions
self.BCE_loss = nn.BCEWithLogitsLoss()
self.L1_loss = nn.L1Loss()
self.MSE_loss = nn.MSELoss()
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['genA2B'].parameters() +
self.nets['genB2A'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['disGA'].parameters() +
self.nets['disGB'].parameters() +
self.nets['disLA'].parameters() +
self.nets['disLB'].parameters())
self.Rho_clipper = RhoClipper(0, 1)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Args:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
mode = 'train' if self.is_train else 'test'
AtoB = self.cfg.dataset[mode].direction == 'AtoB'
if AtoB:
if 'A' in input:
self.real_A = paddle.to_tensor(input['A'])
if 'B' in input:
self.real_B = paddle.to_tensor(input['B'])
else:
if 'B' in input:
self.real_A = paddle.to_tensor(input['B'])
if 'A' in input:
self.real_B = paddle.to_tensor(input['A'])
if 'A_paths' in input:
self.image_paths = input['A_paths']
elif 'B_paths' in input:
self.image_paths = input['B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'):
self.fake_A2B, _, _ = self.nets['genA2B'](self.real_A)
# visual
self.visual_items['real_A'] = self.real_A
self.visual_items['fake_A2B'] = self.fake_A2B
if hasattr(self, 'real_B'):
self.fake_B2A, _, _ = self.nets['genB2A'](self.real_B)
# visual
self.visual_items['real_B'] = self.real_B
self.visual_items['fake_B2A'] = self.fake_B2A
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
self.nets['genA2B'].eval()
self.nets['genB2A'].eval()
with paddle.no_grad():
self.forward()
self.compute_visuals()
self.nets['genA2B'].train()
self.nets['genB2A'].train()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
def _criterion(loss_func, logit, is_real):
if is_real:
target = paddle.ones_like(logit)
else:
target = paddle.zeros_like(logit)
return loss_func(logit, target)
# forward
# compute fake images and reconstruction images.
self.forward()
# update D
self.optimizers['optimizer_D'].clear_grad()
real_GA_logit, real_GA_cam_logit, _ = self.nets['disGA'](self.real_A)
real_LA_logit, real_LA_cam_logit, _ = self.nets['disLA'](self.real_A)
real_GB_logit, real_GB_cam_logit, _ = self.nets['disGB'](self.real_B)
real_LB_logit, real_LB_cam_logit, _ = self.nets['disLB'](self.real_B)
fake_GA_logit, fake_GA_cam_logit, _ = self.nets['disGA'](self.fake_B2A)
fake_LA_logit, fake_LA_cam_logit, _ = self.nets['disLA'](self.fake_B2A)
fake_GB_logit, fake_GB_cam_logit, _ = self.nets['disGB'](self.fake_A2B)
fake_LB_logit, fake_LB_cam_logit, _ = self.nets['disLB'](self.fake_A2B)
D_ad_loss_GA = _criterion(self.MSE_loss,
real_GA_logit, True) + _criterion(
self.MSE_loss, fake_GA_logit, False)
D_ad_cam_loss_GA = _criterion(
self.MSE_loss, real_GA_cam_logit, True) + _criterion(
self.MSE_loss, fake_GA_cam_logit, False)
D_ad_loss_LA = _criterion(self.MSE_loss,
real_LA_logit, True) + _criterion(
self.MSE_loss, fake_LA_logit, False)
D_ad_cam_loss_LA = _criterion(
self.MSE_loss, real_LA_cam_logit, True) + _criterion(
self.MSE_loss, fake_LA_cam_logit, False)
D_ad_loss_GB = _criterion(self.MSE_loss,
real_GB_logit, True) + _criterion(
self.MSE_loss, fake_GB_logit, False)
D_ad_cam_loss_GB = _criterion(
self.MSE_loss, real_GB_cam_logit, True) + _criterion(
self.MSE_loss, fake_GB_cam_logit, False)
D_ad_loss_LB = _criterion(self.MSE_loss,
real_LB_logit, True) + _criterion(
self.MSE_loss, fake_LB_logit, False)
D_ad_cam_loss_LB = _criterion(
self.MSE_loss, real_LB_cam_logit, True) + _criterion(
self.MSE_loss, fake_LB_cam_logit, False)
D_loss_A = self.cfg.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
D_ad_loss_LA + D_ad_cam_loss_LA)
D_loss_B = self.cfg.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
D_ad_loss_LB + D_ad_cam_loss_LB)
Discriminator_loss = D_loss_A + D_loss_B
Discriminator_loss.backward()
self.optimizers['optimizer_D'].step()
# update G
self.optimizers['optimizer_G'].clear_grad()
fake_A2B, fake_A2B_cam_logit, _ = self.nets['genA2B'](self.real_A)
fake_B2A, fake_B2A_cam_logit, _ = self.nets['genB2A'](self.real_B)
fake_A2B2A, _, _ = self.nets['genB2A'](fake_A2B)
fake_B2A2B, _, _ = self.nets['genA2B'](fake_B2A)
fake_A2A, fake_A2A_cam_logit, _ = self.nets['genB2A'](self.real_A)
fake_B2B, fake_B2B_cam_logit, _ = self.nets['genA2B'](self.real_B)
fake_GA_logit, fake_GA_cam_logit, _ = self.nets['disGA'](fake_B2A)
fake_LA_logit, fake_LA_cam_logit, _ = self.nets['disLA'](fake_B2A)
fake_GB_logit, fake_GB_cam_logit, _ = self.nets['disGB'](fake_A2B)
fake_LB_logit, fake_LB_cam_logit, _ = self.nets['disLB'](fake_A2B)
G_ad_loss_GA = _criterion(self.MSE_loss, fake_GA_logit, True)
G_ad_cam_loss_GA = _criterion(self.MSE_loss, fake_GA_cam_logit, True)
G_ad_loss_LA = _criterion(self.MSE_loss, fake_LA_logit, True)
G_ad_cam_loss_LA = _criterion(self.MSE_loss, fake_LA_cam_logit, True)
G_ad_loss_GB = _criterion(self.MSE_loss, fake_GB_logit, True)
G_ad_cam_loss_GB = _criterion(self.MSE_loss, fake_GB_cam_logit, True)
G_ad_loss_LB = _criterion(self.MSE_loss, fake_LB_logit, True)
G_ad_cam_loss_LB = _criterion(self.MSE_loss, fake_LB_cam_logit, True)
G_recon_loss_A = self.L1_loss(fake_A2B2A, self.real_A)
G_recon_loss_B = self.L1_loss(fake_B2A2B, self.real_B)
G_identity_loss_A = self.L1_loss(fake_A2A, self.real_A)
G_identity_loss_B = self.L1_loss(fake_B2B, self.real_B)
G_cam_loss_A = _criterion(self.BCE_loss,
fake_B2A_cam_logit, True) + _criterion(
self.BCE_loss, fake_A2A_cam_logit, False)
G_cam_loss_B = _criterion(self.BCE_loss,
fake_A2B_cam_logit, True) + _criterion(
self.BCE_loss, fake_B2B_cam_logit, False)
G_loss_A = self.cfg.adv_weight * (
G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA
) + self.cfg.cycle_weight * G_recon_loss_A + self.cfg.identity_weight * G_identity_loss_A + self.cfg.cam_weight * G_cam_loss_A
G_loss_B = self.cfg.adv_weight * (
G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB
) + self.cfg.cycle_weight * G_recon_loss_B + self.cfg.identity_weight * G_identity_loss_B + self.cfg.cam_weight * G_cam_loss_B
Generator_loss = G_loss_A + G_loss_B
Generator_loss.backward()
self.optimizers['optimizer_G'].step()
# clip parameter of AdaILN and ILN, applied after optimizer step
self.nets['genA2B'].apply(self.Rho_clipper)
self.nets['genB2A'].apply(self.Rho_clipper)
self.losses['discriminator_loss'] = Discriminator_loss
self.losses['generator_loss'] = Generator_loss
...@@ -65,3 +65,17 @@ class Spectralnorm(paddle.nn.Layer): ...@@ -65,3 +65,17 @@ class Spectralnorm(paddle.nn.Layer):
self.layer.weight = weight self.layer.weight = weight
out = self.layer(x) out = self.layer(x)
return out return out
class RhoClipper(object):
    """Callable that clamps a module's ``rho`` attribute into a fixed range.

    Intended to be handed to a layer's ``apply`` so that it is called on each
    sub-module; modules without a ``rho`` attribute are left untouched.

    Args:
        min: lower clipping bound.
        max: upper clipping bound; must be strictly greater than ``min``.
    """
    def __init__(self, min, max):
        self.clip_min, self.clip_max = min, max
        assert min < max

    def __call__(self, module):
        """Clip ``module.rho`` in place if the module has such an attribute."""
        if not hasattr(module, 'rho'):
            return
        clipped = module.rho.clip(self.clip_min, self.clip_max)
        # Write the clipped values back into the existing object.
        module.rho.set_value(clipped)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .init import normal_
class SpectralNorm(object):
    """Forward-pre-hook that divides a layer weight by an estimate of its spectral norm.

    ``apply`` installs the hook on a layer: the original parameter is kept as
    ``<name>_orig``, the power-iteration vectors are kept as buffers
    ``<name>_u`` / ``<name>_v``, and before every forward call the plain
    attribute ``<name>`` is replaced by ``<name>_orig / sigma``.

    NOTE(review): the nesting of the ``paddle.no_grad()`` scopes below was
    reconstructed from whitespace-stripped source — confirm against the
    original file.
    """
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        """Store the hook configuration.

        Args:
            name (str): name of the layer parameter to normalise.
            n_power_iterations (int): power-iteration steps per forward call.
            dim (int): axis of the weight moved to the front before it is
                flattened to a matrix.
            eps (float): epsilon passed to ``F.normalize``.

        Raises:
            ValueError: if ``n_power_iterations`` is not positive.
        """
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Return ``weight`` as a 2-D matrix with axis ``self.dim`` as its rows."""
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose([
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim]
            ])

        height = weight_mat.shape[0]

        return weight_mat.reshape([height, -1])

    def compute_weight(self, layer, do_power_iteration):
        """Return ``<name>_orig`` divided by ``sigma = u . (W v)``.

        Args:
            layer: layer carrying ``<name>_orig``, ``<name>_u`` and ``<name>_v``.
            do_power_iteration (bool): if True, first refresh ``u`` and ``v``
                in place with ``n_power_iterations`` power-iteration steps.
        """
        weight = getattr(layer, self.name + '_orig')
        u = getattr(layer, self.name + '_u')
        v = getattr(layer, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            # The buffers are updated in place without recording gradients.
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    # v <- normalize(W^T u)
                    v.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat,
                                          u,
                                          transpose_x=True,
                                          transpose_y=False),
                            axis=0,
                            epsilon=self.eps,
                        ))

                    # u <- normalize(W v)
                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps,
                        ))
                # Work on copies so sigma below does not alias the buffers.
                if self.n_power_iterations > 0:
                    u = u.clone()
                    v = v.clone()

        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, layer):
        """Replace the hook's attributes on ``layer`` by a single parameter ``<name>``.

        The parameter value is the normalised weight computed without a
        power-iteration step. This method does not unregister the forward
        pre-hook itself.
        """
        with paddle.no_grad():
            weight = self.compute_weight(layer, do_power_iteration=False)
        delattr(layer, self.name)
        delattr(layer, self.name + '_u')
        delattr(layer, self.name + '_v')
        delattr(layer, self.name + '_orig')

        layer.add_parameter(self.name, weight.detach())

    def __call__(self, layer, inputs):
        """Forward-pre-hook entry point: refresh ``layer.<name>``.

        Power iteration is only performed while ``layer.training`` is true;
        ``inputs`` is unused.
        """
        setattr(layer, self.name,
                self.compute_weight(layer, do_power_iteration=layer.training))

    @staticmethod
    def apply(layer, name, n_power_iterations, dim, eps):
        """Install a ``SpectralNorm`` hook for parameter ``name`` on ``layer``.

        Returns:
            SpectralNorm: the hook object that was registered.

        Raises:
            RuntimeError: if a ``SpectralNorm`` hook for the same parameter
                name is already registered on ``layer``.
        """
        for k, hook in layer._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = layer._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # randomly initialize u and v
            # (normal_ is a project helper; presumably it fills the tensor
            # with samples of mean 0 and std 1 — verify in .init)
            u = layer.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = layer.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters, otherwise you can not set attribute
        del layer._parameters[fn.name]
        layer.add_parameter(fn.name + "_orig", weight)
        # still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be a Parameter and
        # gets added as a parameter. Instead, we register weight * 1.0 as a plain
        # attribute.
        setattr(layer, fn.name, weight * 1.0)
        layer.register_buffer(fn.name + "_u", u)
        layer.register_buffer(fn.name + "_v", v)
        layer.register_forward_pre_hook(fn)
        return fn
def spectral_norm(layer,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    """Attach spectral normalisation to ``layer`` and return the same layer.

    Args:
        layer: layer whose parameter ``name`` is to be normalised.
        name (str): name of the parameter.
        n_power_iterations (int): power-iteration steps per forward call.
        eps (float): epsilon used when normalising the iteration vectors.
        dim (int | None): weight axis treated as the matrix rows. When left
            as None it is 1 for transposed convolutions and ``nn.Linear``
            and 0 for every other layer type.

    Returns:
        The ``layer`` that was passed in, with the hook installed.
    """
    if dim is None:
        # Layer types that get dim=1 (presumably because their weight keeps
        # the output axis in position 1 — confirm against Paddle's layouts).
        dim_one_layers = (nn.Conv1DTranspose, nn.Conv2DTranspose,
                          nn.Conv3DTranspose, nn.Linear)
        dim = 1 if isinstance(layer, dim_one_layers) else 0
    SpectralNorm.apply(layer, name, n_power_iterations, dim, eps)
    return layer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册