未验证 提交 cf153154 编写于 作者: L LielinJiang 提交者: GitHub

fix animegan (#158)

* adapt animegan

* clean code
Co-authored-by: Nqingqing01 <dangqingqing@baidu.com>
上级 6cee870f
epochs: 30
output_dir: output_dir
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
con_weight: 1.5
sty_weight: 2.5
color_weight: 10.
tv_weight: 1.
model:
name: AnimeGANV2Model
......@@ -14,7 +7,16 @@ model:
name: AnimeGenerator
discriminator:
name: AnimeDiscriminator
gan_mode: lsgan
gan_criterion:
name: GANLoss
gan_mode: lsgan
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
con_weight: 1.5
sty_weight: 2.5
color_weight: 10.
tv_weight: 1.
dataset:
train:
......@@ -23,8 +25,6 @@ dataset:
batch_size: 4
dataroot: data/animedataset
style: Hayao
phase: train
direction: AtoB
transform_real:
- name: Transpose
- name: Normalize
......@@ -63,15 +63,25 @@ dataset:
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.00002
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
epochs: 2
output_dir: output_dir
con_weight: 1
pretrain_ckpt: null
model:
name: AnimeGANV2PreTrainModel
......@@ -9,7 +7,11 @@ model:
name: AnimeGenerator
discriminator:
name: AnimeDiscriminator
gan_mode: lsgan
gan_criterion:
name: GANLoss
gan_mode: lsgan
con_weight: 1
pretrain_ckpt: null
dataset:
train:
......@@ -18,8 +20,6 @@ dataset:
batch_size: 4
dataroot: data/animedataset
style: Hayao
phase: train
direction: AtoB
transform_real:
- name: Transpose
- name: Normalize
......@@ -57,15 +57,25 @@ dataset:
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1
optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5
log_config:
interval: 100
......
......@@ -13,8 +13,9 @@
#limitations under the License.
import cv2
import numpy as np
import os.path
import numpy as np
import paddle
from .base_dataset import BaseDataset
from .image_folder import ImageFolder
......@@ -23,21 +24,27 @@ from .transforms.builder import build_transforms
@DATASETS.register()
class AnimeGANV2Dataset(BaseDataset):
class AnimeGANV2Dataset(paddle.io.Dataset):
"""
"""
def __init__(self, cfg):
def __init__(self,
dataroot,
style,
transform_real=None,
transform_anime=None,
transform_gray=None):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
"""
BaseDataset.__init__(self, cfg)
self.style = cfg.style
# self.cfg = cfg
self.root = dataroot
self.style = style
self.transform_real = build_transforms(self.cfg.transform_real)
self.transform_anime = build_transforms(self.cfg.transform_anime)
self.transform_gray = build_transforms(self.cfg.transform_gray)
self.transform_real = build_transforms(transform_real)
self.transform_anime = build_transforms(transform_anime)
self.transform_gray = build_transforms(transform_gray)
self.real_root = os.path.join(self.root, 'train_photo')
self.anime_root = os.path.join(self.root, f'{self.style}', 'style')
......
......@@ -13,62 +13,66 @@
#limitations under the License.
import paddle
from paddle import nn
import paddle.nn as nn
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions.gan_loss import GANLoss
from .criterions import build_criterion
from ..modules.caffevgg import CaffeVGG19
from ..solver import build_optimizer
from ..modules.init import init_weights
from ..utils.filesystem import load
@MODELS.register()
class AnimeGANV2Model(BaseModel):
def __init__(self, cfg):
def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
pretrain_ckpt=None,
g_adv_weight=300.,
d_adv_weight=300.,
con_weight=1.5,
sty_weight=2.5,
color_weight=10.,
tv_weight=1.):
"""Initialize the AnimeGANV2 class.
Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
"""
super(AnimeGANV2Model, self).__init__(cfg)
super(AnimeGANV2Model, self).__init__()
self.g_adv_weight = g_adv_weight
self.d_adv_weight = d_adv_weight
self.con_weight = con_weight
self.sty_weight = sty_weight
self.color_weight = color_weight
self.tv_weight = tv_weight
# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(cfg.model.generator)
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator)
self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD'])
self.pretrained = CaffeVGG19()
self.losses = {}
# define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode)
self.criterionGAN = build_criterion(gan_criterion)
self.criterionL1 = nn.L1Loss()
self.criterionHub = nn.SmoothL1Loss()
# build optimizers
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())
if self.cfg.pretrain_ckpt:
state_dicts = load(self.cfg.pretrain_ckpt)
if pretrain_ckpt:
state_dicts = load(pretrain_ckpt)
self.nets['netG'].set_state_dict(state_dicts['netG'])
print('Load pretrained generator from', self.cfg.pretrain_ckpt)
print('Load pretrained generator from', pretrain_ckpt)
def set_input(self, input):
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
"""
......@@ -152,13 +156,13 @@ class AnimeGANV2Model(BaseModel):
fake_logit = self.nets['netD'](self.fake.detach())
smooth_logit = self.nets['netD'](self.smooth_gray)
d_real_loss = (self.cfg.d_adv_weight * 1.2 *
d_real_loss = (self.d_adv_weight * 1.2 *
self.criterionGAN(real_logit, True))
d_gray_loss = (self.cfg.d_adv_weight * 1.2 *
d_gray_loss = (self.d_adv_weight * 1.2 *
self.criterionGAN(gray_logit, False))
d_fake_loss = (self.cfg.d_adv_weight * 1.2 *
d_fake_loss = (self.d_adv_weight * 1.2 *
self.criterionGAN(fake_logit, False))
d_blur_loss = (self.cfg.d_adv_weight * 0.8 *
d_blur_loss = (self.d_adv_weight * 0.8 *
self.criterionGAN(smooth_logit, False))
self.loss_D = d_real_loss + d_gray_loss + d_fake_loss + d_blur_loss
......@@ -175,11 +179,11 @@ class AnimeGANV2Model(BaseModel):
fake_logit = self.nets['netD'](self.fake)
c_loss, s_loss = self.con_sty_loss(self.real, self.anime_gray,
self.fake)
c_loss = self.cfg.con_weight * c_loss
s_loss = self.cfg.sty_weight * s_loss
tv_loss = self.cfg.tv_weight * self.variation_loss(self.fake)
col_loss = self.cfg.color_weight * self.color_loss(self.real, self.fake)
g_loss = (self.cfg.g_adv_weight * self.criterionGAN(fake_logit, True))
c_loss = self.con_weight * c_loss
s_loss = self.sty_weight * s_loss
tv_loss = self.tv_weight * self.variation_loss(self.fake)
col_loss = self.color_weight * self.color_loss(self.real, self.fake)
g_loss = (self.g_adv_weight * self.criterionGAN(fake_logit, True))
self.loss_G = c_loss + s_loss + col_loss + g_loss + tv_loss
......@@ -191,7 +195,7 @@ class AnimeGANV2Model(BaseModel):
self.losses['col_loss'] = col_loss
self.losses['tv_loss'] = tv_loss
def optimize_parameters(self):
def train_iter(self, optimizers=None):
# compute fake images: G(A)
self.forward()
......@@ -212,11 +216,11 @@ class AnimeGANV2PreTrainModel(AnimeGANV2Model):
real_feature_map = self.pretrained(self.real)
fake_feature_map = self.pretrained(self.fake)
init_c_loss = self.criterionL1(real_feature_map, fake_feature_map)
loss = self.cfg.con_weight * init_c_loss
loss = self.con_weight * init_c_loss
loss.backward()
self.losses['init_c_loss'] = init_c_loss
def optimize_parameters(self):
def train_iter(self, optimizers=None):
self.forward()
# update G
self.optimizers['optimizer_G'].clear_grad()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册