未验证 提交 cf153154 编写于 作者: L LielinJiang 提交者: GitHub

fix animegan (#158)

* adapt animegan

* clean code
Co-authored-by: Nqingqing01 <dangqingqing@baidu.com>
上级 6cee870f
epochs: 30 epochs: 30
output_dir: output_dir output_dir: output_dir
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
con_weight: 1.5
sty_weight: 2.5
color_weight: 10.
tv_weight: 1.
model: model:
name: AnimeGANV2Model name: AnimeGANV2Model
...@@ -14,7 +7,16 @@ model: ...@@ -14,7 +7,16 @@ model:
name: AnimeGenerator name: AnimeGenerator
discriminator: discriminator:
name: AnimeDiscriminator name: AnimeDiscriminator
gan_mode: lsgan gan_criterion:
name: GANLoss
gan_mode: lsgan
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
con_weight: 1.5
sty_weight: 2.5
color_weight: 10.
tv_weight: 1.
dataset: dataset:
train: train:
...@@ -23,8 +25,6 @@ dataset: ...@@ -23,8 +25,6 @@ dataset:
batch_size: 4 batch_size: 4
dataroot: data/animedataset dataroot: data/animedataset
style: Hayao style: Hayao
phase: train
direction: AtoB
transform_real: transform_real:
- name: Transpose - name: Transpose
- name: Normalize - name: Normalize
...@@ -63,15 +63,25 @@ dataset: ...@@ -63,15 +63,25 @@ dataset:
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5] std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
lr_scheduler: lr_scheduler:
name: linear name: LinearDecay
learning_rate: 0.00002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
decay_epochs: 100 decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config: log_config:
interval: 100 interval: 100
......
epochs: 2 epochs: 2
output_dir: output_dir output_dir: output_dir
con_weight: 1
pretrain_ckpt: null
model: model:
name: AnimeGANV2PreTrainModel name: AnimeGANV2PreTrainModel
...@@ -9,7 +7,11 @@ model: ...@@ -9,7 +7,11 @@ model:
name: AnimeGenerator name: AnimeGenerator
discriminator: discriminator:
name: AnimeDiscriminator name: AnimeDiscriminator
gan_mode: lsgan gan_criterion:
name: GANLoss
gan_mode: lsgan
con_weight: 1
pretrain_ckpt: null
dataset: dataset:
train: train:
...@@ -18,8 +20,6 @@ dataset: ...@@ -18,8 +20,6 @@ dataset:
batch_size: 4 batch_size: 4
dataroot: data/animedataset dataroot: data/animedataset
style: Hayao style: Hayao
phase: train
direction: AtoB
transform_real: transform_real:
- name: Transpose - name: Transpose
- name: Normalize - name: Normalize
...@@ -57,15 +57,25 @@ dataset: ...@@ -57,15 +57,25 @@ dataset:
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5] std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
lr_scheduler: lr_scheduler:
name: linear name: LinearDecay
learning_rate: 0.0002 learning_rate: 0.0002
start_epoch: 100 start_epoch: 100
decay_epochs: 100 decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config: log_config:
interval: 100 interval: 100
......
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
#limitations under the License. #limitations under the License.
import cv2 import cv2
import numpy as np
import os.path import os.path
import numpy as np
import paddle
from .base_dataset import BaseDataset from .base_dataset import BaseDataset
from .image_folder import ImageFolder from .image_folder import ImageFolder
...@@ -23,21 +24,27 @@ from .transforms.builder import build_transforms ...@@ -23,21 +24,27 @@ from .transforms.builder import build_transforms
@DATASETS.register() @DATASETS.register()
class AnimeGANV2Dataset(BaseDataset): class AnimeGANV2Dataset(paddle.io.Dataset):
""" """
""" """
def __init__(self, cfg): def __init__(self,
dataroot,
style,
transform_real=None,
transform_anime=None,
transform_gray=None):
"""Initialize this dataset class. """Initialize this dataset class.
Args: Args:
cfg (dict) -- stores all the experiment flags cfg (dict) -- stores all the experiment flags
""" """
BaseDataset.__init__(self, cfg) # self.cfg = cfg
self.style = cfg.style self.root = dataroot
self.style = style
self.transform_real = build_transforms(self.cfg.transform_real) self.transform_real = build_transforms(transform_real)
self.transform_anime = build_transforms(self.cfg.transform_anime) self.transform_anime = build_transforms(transform_anime)
self.transform_gray = build_transforms(self.cfg.transform_gray) self.transform_gray = build_transforms(transform_gray)
self.real_root = os.path.join(self.root, 'train_photo') self.real_root = os.path.join(self.root, 'train_photo')
self.anime_root = os.path.join(self.root, f'{self.style}', 'style') self.anime_root = os.path.join(self.root, f'{self.style}', 'style')
......
...@@ -13,62 +13,66 @@ ...@@ -13,62 +13,66 @@
#limitations under the License. #limitations under the License.
import paddle import paddle
from paddle import nn import paddle.nn as nn
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from .criterions.gan_loss import GANLoss from .criterions import build_criterion
from ..modules.caffevgg import CaffeVGG19 from ..modules.caffevgg import CaffeVGG19
from ..solver import build_optimizer
from ..modules.init import init_weights from ..modules.init import init_weights
from ..utils.filesystem import load from ..utils.filesystem import load
@MODELS.register() @MODELS.register()
class AnimeGANV2Model(BaseModel): class AnimeGANV2Model(BaseModel):
def __init__(self, cfg): def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
pretrain_ckpt=None,
g_adv_weight=300.,
d_adv_weight=300.,
con_weight=1.5,
sty_weight=2.5,
color_weight=10.,
tv_weight=1.):
"""Initialize the AnimeGANV2 class. """Initialize the AnimeGANV2 class.
Parameters: Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
""" """
super(AnimeGANV2Model, self).__init__(cfg) super(AnimeGANV2Model, self).__init__()
self.g_adv_weight = g_adv_weight
self.d_adv_weight = d_adv_weight
self.con_weight = con_weight
self.sty_weight = sty_weight
self.color_weight = color_weight
self.tv_weight = tv_weight
# define networks (both generator and discriminator) # define networks (both generator and discriminator)
self.nets['netG'] = build_generator(cfg.model.generator) self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG']) init_weights(self.nets['netG'])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if self.is_train: if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator) self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD']) init_weights(self.nets['netD'])
self.pretrained = CaffeVGG19() self.pretrained = CaffeVGG19()
self.losses = {} self.losses = {}
# define loss functions # define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode) self.criterionGAN = build_criterion(gan_criterion)
self.criterionL1 = nn.L1Loss() self.criterionL1 = nn.L1Loss()
self.criterionHub = nn.SmoothL1Loss() self.criterionHub = nn.SmoothL1Loss()
# build optimizers if pretrain_ckpt:
self.build_lr_scheduler() state_dicts = load(pretrain_ckpt)
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())
if self.cfg.pretrain_ckpt:
state_dicts = load(self.cfg.pretrain_ckpt)
self.nets['netG'].set_state_dict(state_dicts['netG']) self.nets['netG'].set_state_dict(state_dicts['netG'])
print('Load pretrained generator from', self.cfg.pretrain_ckpt) print('Load pretrained generator from', pretrain_ckpt)
def set_input(self, input): def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
""" """
...@@ -152,13 +156,13 @@ class AnimeGANV2Model(BaseModel): ...@@ -152,13 +156,13 @@ class AnimeGANV2Model(BaseModel):
fake_logit = self.nets['netD'](self.fake.detach()) fake_logit = self.nets['netD'](self.fake.detach())
smooth_logit = self.nets['netD'](self.smooth_gray) smooth_logit = self.nets['netD'](self.smooth_gray)
d_real_loss = (self.cfg.d_adv_weight * 1.2 * d_real_loss = (self.d_adv_weight * 1.2 *
self.criterionGAN(real_logit, True)) self.criterionGAN(real_logit, True))
d_gray_loss = (self.cfg.d_adv_weight * 1.2 * d_gray_loss = (self.d_adv_weight * 1.2 *
self.criterionGAN(gray_logit, False)) self.criterionGAN(gray_logit, False))
d_fake_loss = (self.cfg.d_adv_weight * 1.2 * d_fake_loss = (self.d_adv_weight * 1.2 *
self.criterionGAN(fake_logit, False)) self.criterionGAN(fake_logit, False))
d_blur_loss = (self.cfg.d_adv_weight * 0.8 * d_blur_loss = (self.d_adv_weight * 0.8 *
self.criterionGAN(smooth_logit, False)) self.criterionGAN(smooth_logit, False))
self.loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss self.loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss
...@@ -175,11 +179,11 @@ class AnimeGANV2Model(BaseModel): ...@@ -175,11 +179,11 @@ class AnimeGANV2Model(BaseModel):
fake_logit = self.nets['netD'](self.fake) fake_logit = self.nets['netD'](self.fake)
c_loss, s_loss = self.con_sty_loss(self.real, self.anime_gray, c_loss, s_loss = self.con_sty_loss(self.real, self.anime_gray,
self.fake) self.fake)
c_loss = self.cfg.con_weight * c_loss c_loss = self.con_weight * c_loss
s_loss = self.cfg.sty_weight * s_loss s_loss = self.sty_weight * s_loss
tv_loss = self.cfg.tv_weight * self.variation_loss(self.fake) tv_loss = self.tv_weight * self.variation_loss(self.fake)
col_loss = self.cfg.color_weight * self.color_loss(self.real, self.fake) col_loss = self.color_weight * self.color_loss(self.real, self.fake)
g_loss = (self.cfg.g_adv_weight * self.criterionGAN(fake_logit, True)) g_loss = (self.g_adv_weight * self.criterionGAN(fake_logit, True))
self.loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss self.loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss
...@@ -191,7 +195,7 @@ class AnimeGANV2Model(BaseModel): ...@@ -191,7 +195,7 @@ class AnimeGANV2Model(BaseModel):
self.losses['col_loss'] = col_loss self.losses['col_loss'] = col_loss
self.losses['tv_loss'] = tv_loss self.losses['tv_loss'] = tv_loss
def optimize_parameters(self): def train_iter(self, optimizers=None):
# compute fake images: G(A) # compute fake images: G(A)
self.forward() self.forward()
...@@ -212,11 +216,11 @@ class AnimeGANV2PreTrainModel(AnimeGANV2Model): ...@@ -212,11 +216,11 @@ class AnimeGANV2PreTrainModel(AnimeGANV2Model):
real_feature_map = self.pretrained(self.real) real_feature_map = self.pretrained(self.real)
fake_feature_map = self.pretrained(self.fake) fake_feature_map = self.pretrained(self.fake)
init_c_loss = self.criterionL1(real_feature_map, fake_feature_map) init_c_loss = self.criterionL1(real_feature_map, fake_feature_map)
loss = self.cfg.con_weight * init_c_loss loss = self.con_weight * init_c_loss
loss.backward() loss.backward()
self.losses['init_c_loss'] = init_c_loss self.losses['init_c_loss'] = init_c_loss
def optimize_parameters(self): def train_iter(self, optimizers=None):
self.forward() self.forward()
# update G # update G
self.optimizers['optimizer_G'].clear_grad() self.optimizers['optimizer_G'].clear_grad()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册