Commit 2a9bd0d8 authored by zhumanyu, committed by lvmengsi

add Pix2pix to gan library (#2296)

* add pix2pix to gan library
Parent 21b00fb7
...@@ -22,6 +22,7 @@ import argparse
import struct
import os
import paddle
import random
def RandomCrop(img, crop_w, crop_h):
...@@ -45,6 +46,18 @@ def RandomHorizonFlip(img):
return img
def get_preprocess_param(load_size, crop_size):
x = np.random.randint(0, np.maximum(0, load_size - crop_size))
y = np.random.randint(0, np.maximum(0, load_size - crop_size))
flip = np.random.rand() > 0.5
return {
"crop_pos": (x, y),
"flip": flip,
"load_size": load_size,
"crop_size": crop_size
}
class reader_creator(object):
''' read and preprocess dataset'''
...@@ -122,6 +135,108 @@ class reader_creator(object):
return reader
class pair_reader_creator(reader_creator):
''' read and preprocess dataset'''
def __init__(self, image_dir, list_filename, batch_size=1, drop_last=False):
super(pair_reader_creator, self).__init__(
image_dir, list_filename, batch_size=batch_size, drop_last=drop_last)
def get_train_reader(self, args, shuffle=False, return_name=False):
print(self.image_dir, self.list_filename)
def reader():
batch_out_1 = []
batch_out_2 = []
while True:
if shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
files = line.strip('\n\r\t ').split('\t')
img1 = Image.open(os.path.join(self.image_dir, files[
0])).convert('RGB')
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
param = get_preprocess_param(args.load_size, args.crop_size)
img1 = img1.resize((args.load_size, args.load_size),
Image.BICUBIC)
img2 = img2.resize((args.load_size, args.load_size),
Image.BICUBIC)
if args.crop_type == 'Centor':
img1 = CentorCrop(img1, args.crop_size, args.crop_size)
img2 = CentorCrop(img2, args.crop_size, args.crop_size)
elif args.crop_type == 'Random':
x = param['crop_pos'][0]
y = param['crop_pos'][1]
img1 = img1.crop(
(x, y, x + args.crop_size, y + args.crop_size))
img2 = img2.crop(
(x, y, x + args.crop_size, y + args.crop_size))
img1 = (
np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5
img1 = img1.transpose([2, 0, 1])
img2 = (
np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1])
batch_out_1.append(img1)
batch_out_2.append(img2)
if len(batch_out_1) == self.batch_size:
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
yield batch_out_1, batch_out_2
return reader
def get_test_reader(self, args, shuffle=False, return_name=False):
print(self.image_dir, self.list_filename)
def reader():
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
for line in self.lines:
files = line.strip('\n\r\t ').split('\t')
img1 = Image.open(os.path.join(self.image_dir, files[
0])).convert('RGB')
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
img1 = img1.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
img2 = img2.resize((args.crop_size, args.crop_size),
Image.BICUBIC)
img1 = (np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5
img1 = img1.transpose([2, 0, 1])
img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1])
if return_name:
batch_out_1.append(img1)
batch_out_2.append(img2)
batch_out_3.append(os.path.basename(files[0]))
else:
batch_out_1.append(img1)
batch_out_2.append(img2)
if len(batch_out_1) == self.batch_size:
if return_name:
yield batch_out_1, batch_out_2, batch_out_3
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
else:
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if len(batch_out_1) != 0:
if return_name:
yield batch_out_1, batch_out_2, batch_out_3
else:
yield batch_out_1, batch_out_2
return reader
def mnist_reader_creator(image_filename, label_filename, buffer_size):
def reader():
with gzip.GzipFile(image_filename, 'rb') as image_file:
...@@ -231,6 +346,32 @@ class data_reader(object):
return a_reader, b_reader, a_reader_test, b_reader_test, batch_num
elif self.cfg.model_net == 'Pix2pix':
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
if self.cfg.train_list is not None:
train_list = self.cfg.train_list
train_reader = pair_reader_creator(
image_dir=dataset_dir,
list_filename=train_list,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last)
reader_test = None
if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt")
if self.cfg.test_list is not None:
test_list = self.cfg.test_list
test_reader = pair_reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
batch_size=1,
drop_last=self.cfg.drop_last)
reader_test = test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
batch_num = train_reader.len()
reader = train_reader.get_train_reader(
self.cfg, shuffle=self.shuffle)
return reader, reader_test, batch_num
else:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
...
...@@ -57,6 +57,11 @@ def infer(args):
fake = network_G(input, name="GB", cfg=args)
else:
raise "Input with style [%s] is not supported." % args.input_style
elif args.model_net == 'Pix2pix':
from network.Pix2pix_network import Pix2pix_model
model = Pix2pix_model()
fake = model.network_G(input, "generator", cfg=args)
elif args.model_net == 'cgan':
pass
else:
...
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .base_network import conv2d, deconv2d, norm_layer
import paddle.fluid as fluid
class Pix2pix_model(object):
def __init__(self):
pass
def network_G(self, input, name, cfg):
if cfg.net_G == 'resnet_9block':
net = build_generator_resnet_blocks(
input,
name=name + "_resnet9block",
n_gen_res=9,
g_base_dims=cfg.g_base_dims,
use_dropout=cfg.dropout,
norm_type=cfg.norm_type)
elif cfg.net_G == 'resnet_6block':
net = build_generator_resnet_blocks(
input,
name=name + "_resnet6block",
n_gen_res=6,
g_base_dims=cfg.g_base_dims,
use_dropout=cfg.dropout,
norm_type=cfg.norm_type)
elif cfg.net_G == 'unet_128':
net = build_generator_Unet(
input,
name=name + "_unet128",
num_downsample=7,
g_base_dims=cfg.g_base_dims,
use_dropout=cfg.dropout,
norm_type=cfg.norm_type)
elif cfg.net_G == 'unet_256':
net = build_generator_Unet(
input,
name=name + "_unet256",
num_downsample=8,
g_base_dims=cfg.g_base_dims,
use_dropout=cfg.dropout,
norm_type=cfg.norm_type)
else:
raise NotImplementedError(
'network G: [%s] is wrong format, please check it' % cfg.net_G)
return net
def network_D(self, input, name, cfg):
if cfg.net_D == 'basic':
net = build_discriminator_Nlayers(
input,
name=name + '_basic',
d_nlayers=3,
d_base_dims=cfg.d_base_dims,
norm_type=cfg.norm_type)
elif cfg.net_D == 'nlayers':
net = build_discriminator_Nlayers(
input,
name=name + '_nlayers',
d_nlayers=cfg.d_nlayers,
d_base_dims=cfg.d_base_dims,
norm_type=cfg.norm_type)
elif cfg.net_D == 'pixel':
net = build_discriminator_Pixel(
input,
name=name + '_pixel',
d_base_dims=cfg.d_base_dims,
norm_type=cfg.norm_type)
else:
raise NotImplementedError(
'network D: [%s] is wrong format, please check it' % cfg.net_D)
return net
def build_resnet_block(inputres,
dim,
name="resnet",
use_bias=False,
use_dropout=False,
norm_type='batch_norm'):
out_res = fluid.layers.pad2d(inputres, [1, 1, 1, 1], mode="reflect")
out_res = conv2d(
out_res,
dim,
3,
1,
0.02,
name=name + "_c1",
norm=norm_type,
activation_fn='relu',
use_bias=use_bias)
if use_dropout:
out_res = fluid.layers.dropout(out_res, dropout_prob=0.5)
out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
out_res = conv2d(
out_res,
dim,
3,
1,
0.02,
name=name + "_c2",
norm=norm_type,
use_bias=use_bias)
return out_res + inputres
def build_generator_resnet_blocks(inputgen,
name="generator",
n_gen_res=9,
g_base_dims=64,
use_dropout=False,
norm_type='batch_norm'):
'''Generator built from ResNet blocks; the output shape matches the input shape.'''
use_bias = norm_type == 'instance_norm'
pad_input = fluid.layers.pad2d(inputgen, [3, 3, 3, 3], mode="reflect")
o_c1 = conv2d(
pad_input,
g_base_dims,
7,
1,
0.02,
name=name + "_c1",
norm=norm_type,
activation_fn='relu')
o_c2 = conv2d(
o_c1,
g_base_dims * 2,
3,
2,
0.02,
1,
name=name + "_c2",
norm=norm_type,
activation_fn='relu')
res_input = conv2d(
o_c2,
g_base_dims * 4,
3,
2,
0.02,
1,
name=name + "_c3",
norm=norm_type,
activation_fn='relu')
for i in range(n_gen_res):
conv_name = name + "_r{}".format(i + 1)
res_output = build_resnet_block(
res_input,
g_base_dims * 4,
name=conv_name,
use_bias=use_bias,
use_dropout=use_dropout)
res_input = res_output
o_c4 = deconv2d(
res_output,
g_base_dims * 2,
3,
2,
0.02, [1, 1], [0, 1, 0, 1],
name=name + "_c4",
norm=norm_type,
activation_fn='relu')
o_c5 = deconv2d(
o_c4,
g_base_dims,
3,
2,
0.02, [1, 1], [0, 1, 0, 1],
name=name + "_c5",
norm=norm_type,
activation_fn='relu')
o_p2 = fluid.layers.pad2d(o_c5, [3, 3, 3, 3], mode="reflect")
o_c6 = conv2d(
o_p2,
3,
7,
1,
0.02,
name=name + "_c6",
activation_fn='tanh',
use_bias=True)
return o_c6
def Unet_block(inputunet,
i,
outer_dim,
inner_dim,
num_downsample,
innermost=False,
outermost=False,
norm_type='batch_norm',
use_bias=False,
use_dropout=False,
name=None):
if outermost:
downconv = conv2d(
inputunet,
inner_dim,
4,
2,
0.02,
1,
name=name + '_outermost_dc1',
use_bias=True)
i += 1
mid_block = Unet_block(
downconv,
i,
inner_dim,
inner_dim * 2,
num_downsample,
norm_type=norm_type,
use_bias=use_bias,
use_dropout=use_dropout,
name=name)
uprelu = fluid.layers.relu(mid_block, name=name + '_outermost_relu')
updeconv = deconv2d(
uprelu,
outer_dim,
4,
2,
0.02,
1,
name=name + '_outermost_uc1',
activation_fn='tanh',
use_bias=use_bias)
return updeconv
elif innermost:
downrelu = fluid.layers.leaky_relu(
inputunet, 0.2, name=name + '_innermost_leaky_relu')
upconv = conv2d(
downrelu,
inner_dim,
4,
2,
0.02,
1,
name=name + '_innermost_dc1',
activation_fn='relu',
use_bias=use_bias)
updeconv = deconv2d(
upconv,
outer_dim,
4,
2,
0.02,
1,
name=name + '_innermost_uc1',
norm=norm_type,
use_bias=use_bias)
return fluid.layers.concat([inputunet, updeconv], 1)
else:
downrelu = fluid.layers.leaky_relu(
inputunet, 0.2, name=name + '_leaky_relu')
downnorm = conv2d(
downrelu,
inner_dim,
4,
2,
0.02,
1,
name=name + 'dc1',
norm=norm_type,
use_bias=use_bias)
i += 1
if i < 4:
mid_block = Unet_block(
downnorm,
i,
inner_dim,
inner_dim * 2,
num_downsample,
norm_type=norm_type,
use_bias=use_bias,
name=name + '_mid{}'.format(i))
elif i < num_downsample - 1:
mid_block = Unet_block(
downnorm,
i,
inner_dim,
inner_dim,
num_downsample,
norm_type=norm_type,
use_bias=use_bias,
use_dropout=use_dropout,
name=name + '_mid{}'.format(i))
else:
mid_block = Unet_block(
downnorm,
i,
inner_dim,
inner_dim,
num_downsample,
innermost=True,
norm_type=norm_type,
use_bias=use_bias,
name=name + '_innermost')
uprelu = fluid.layers.relu(mid_block, name=name + '_relu')
updeconv = deconv2d(
uprelu,
outer_dim,
4,
2,
0.02,
1,
name=name + '_uc1',
norm=norm_type,
use_bias=use_bias)
if use_dropout:
updeconv = fluid.layers.dropout(updeconv, dropout_prob=0.5)
return fluid.layers.concat([inputunet, updeconv], 1)
def UnetSkipConnectionBlock(input,
i,
num_downs,
outer_nc,
inner_nc,
outermost=False,
innermost=False,
norm='batch_norm',
use_dropout=False,
name=""):
use_bias = norm == "instance_norm"
if outermost:
downconv = conv2d(
input,
inner_nc,
4,
2,
padding=1,
use_bias=use_bias,
name=name + '_down_conv')
i += 1
ngf = inner_nc
sub_res = UnetSkipConnectionBlock(
downconv,
i,
num_downs,
outer_nc=ngf,
inner_nc=ngf * 2,
norm=norm,
name=name + '_u%d' % i)
uprelu = fluid.layers.relu(sub_res)
upconv = deconv2d(
uprelu,
outer_nc,
4,
2,
padding=1,
activation_fn='tanh',
name=name + '_up_conv')
return upconv
elif innermost:
downrelu = fluid.layers.leaky_relu(input, 0.2)
downconv = conv2d(
downrelu,
inner_nc,
4,
2,
padding=1,
use_bias=use_bias,
name=name + '_down_conv')
uprelu = fluid.layers.relu(downconv)
upconv = deconv2d(
uprelu,
outer_nc,
4,
2,
padding=1,
use_bias=use_bias,
norm=norm,
name=name + '_up_conv')
return fluid.layers.concat([input, upconv], 1)
else:
downrelu = fluid.layers.leaky_relu(input, 0.2)
downconv = conv2d(
downrelu,
inner_nc,
4,
2,
padding=1,
use_bias=use_bias,
norm=norm,
name=name + '_down_conv')
i += 1
ngf = inner_nc
if i < 4:
sub_res = UnetSkipConnectionBlock(
downconv,
i,
num_downs,
outer_nc=ngf,
inner_nc=ngf * 2,
norm=norm,
name=name + '_u%d' % i)
elif i < num_downs - 1:
sub_res = UnetSkipConnectionBlock(
downconv,
i,
num_downs,
outer_nc=ngf,
inner_nc=ngf,
norm=norm,
name=name + '_u%d' % i)
else:
sub_res = UnetSkipConnectionBlock(
downconv,
i,
num_downs,
outer_nc=ngf,
inner_nc=ngf,
innermost=True,
norm=norm,
name=name + '_u%d' % i)
uprelu = fluid.layers.relu(sub_res)
upconv = deconv2d(
uprelu,
outer_nc,
4,
2,
padding=1,
use_bias=use_bias,
norm=norm,
name=name + '_up_conv')
out = upconv
if use_dropout:
out = fluid.layers.dropout(out, 0.5)
return fluid.layers.concat([input, out], 1)
def build_generator_Unet(input,
name="",
num_downsample=8,
g_base_dims=64,
use_dropout=False,
norm_type='batch_norm'):
'''Generator built with a U-Net architecture.'''
i = 0
output = UnetSkipConnectionBlock(
input,
i,
num_downsample,
3,
g_base_dims,
outermost=True,
norm=norm_type,
name=name + '_u%d' % i)
return output
def build_discriminator_Nlayers(inputdisc,
name="discriminator",
d_nlayers=3,
d_base_dims=64,
norm_type='batch_norm'):
use_bias = norm_type != 'batch_norm'
dis_input = conv2d(
inputdisc,
d_base_dims,
4,
2,
0.02,
1,
name=name + "_c1",
activation_fn='leaky_relu',
relufactor=0.2,
use_bias=True)
d_dims = d_base_dims
for i in range(d_nlayers - 1):
conv_name = name + "_c{}".format(i + 2)
d_dims *= 2
dis_output = conv2d(
dis_input,
d_dims,
4,
2,
0.02,
1,
name=conv_name,
norm=norm_type,
activation_fn='leaky_relu',
relufactor=0.2,
use_bias=use_bias)
dis_input = dis_output
last_dims = min(2**d_nlayers, 8)
o_c4 = conv2d(
dis_output,
d_base_dims * last_dims,
4,
1,
0.02,
1,
name + "_c{}".format(d_nlayers + 1),
norm=norm_type,
activation_fn='leaky_relu',
relufactor=0.2,
use_bias=use_bias)
o_c5 = conv2d(
o_c4,
1,
4,
1,
0.02,
1,
name + "_c{}".format(d_nlayers + 2),
use_bias=True)
return o_c5
def build_discriminator_Pixel(inputdisc,
name="discriminator",
d_base_dims=64,
norm_type='batch_norm'):
use_bias = norm_type != 'batch_norm'
o_c1 = conv2d(
inputdisc,
d_base_dims,
1,
1,
0.02,
name=name + '_c1',
activation_fn='leaky_relu',
relufactor=0.2,
use_bias=True)
o_c2 = conv2d(
o_c1,
d_base_dims * 2,
1,
1,
0.02,
name=name + '_c2',
norm=norm_type,
activation_fn='leaky_relu',
relufactor=0.2,
use_bias=use_bias)
o_c3 = conv2d(o_c2, 1, 1, 1, 0.02, name=name + '_c3', use_bias=use_bias)
return o_c3
python infer.py --init_model output/checkpoints/15/ --input data/cityscapes/test/B/100.jpg --model_net Pix2pix --net_G unet_256
python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list10 --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err
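Note: pair_reader_creator above splits each line of the files passed via --train_list / --test_list on a tab and resolves the two paths relative to the dataset directory; the first path is the input image and the second is the target. The file names below are purely illustrative:
train/A/1.jpg	train/B/1.jpg
train/A/2.jpg	train/B/2.jpg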
...@@ -31,6 +31,8 @@ def train(cfg):
if cfg.model_net == 'CycleGAN':
a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
)
elif cfg.model_net == 'Pix2pix':
train_reader, test_reader, batch_num = reader.make_data()
else:
if cfg.dataset == 'mnist':
train_reader = reader.make_data()
...@@ -51,6 +53,9 @@ def train(cfg):
from trainer.CycleGAN import CycleGAN
model = CycleGAN(cfg, a_reader, b_reader, a_reader_test, b_reader_test,
batch_num)
elif cfg.model_net == 'Pix2pix':
from trainer.Pix2pix import Pix2pix
model = Pix2pix(cfg, train_reader, test_reader, batch_num)
else:
pass
...
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.Pix2pix_network import Pix2pix_model
from util import utility
import paddle.fluid as fluid
import sys
import time
class GTrainer():
def __init__(self, input_A, input_B, cfg, step_per_epoch):
self.program = fluid.default_main_program().clone()
with fluid.program_guard(self.program):
model = Pix2pix_model()
self.fake_B = model.network_G(input_A, "generator", cfg=cfg)
self.fake_B.persistable = True
self.infer_program = self.program.clone()
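# Conditional GAN discriminator: it scores the input image concatenated with the generated image along the channel dimension.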
AB = fluid.layers.concat([input_A, self.fake_B], 1)
self.pred = model.network_D(AB, "discriminator", cfg)
if cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred,
shape=self.pred.shape,
value=1,
dtype='float32')
self.g_loss_gan = fluid.layers.reduce_mean(
fluid.layers.square(
fluid.layers.elementwise_sub(
x=self.pred, y=ones)))
elif cfg.gan_mode == "vanilla":
pred_shape = self.pred.shape
self.pred = fluid.layers.reshape(
self.pred,
[-1, pred_shape[1] * pred_shape[2] * pred_shape[3]],
inplace=True)
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred,
shape=self.pred.shape,
value=1,
dtype='float32')
self.g_loss_gan = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.pred, label=ones))
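# L1 reconstruction loss between the real target image and the generated image, weighted by lambda_L1.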
self.g_loss_L1 = fluid.layers.reduce_mean(
fluid.layers.abs(
fluid.layers.elementwise_sub(
x=input_B, y=self.fake_B))) * cfg.lambda_L1
self.g_loss = fluid.layers.elementwise_add(self.g_loss_L1,
self.g_loss_gan)
lr = cfg.learning_rate
vars = []
for var in self.program.list_vars():
if fluid.io.is_parameter(var) and var.name.startswith(
"generator"):
vars.append(var.name)
self.param = vars
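# Piecewise schedule: the learning rate stays at its initial value for roughly the first 100 epochs, then decays linearly toward zero.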
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
]),
beta1=0.5,
beta2=0.999,
name="net_G")
optimizer.minimize(self.g_loss, parameter_list=vars)
class DTrainer():
def __init__(self, input_A, input_B, fake_B, cfg, step_per_epoch):
self.program = fluid.default_main_program().clone()
lr = cfg.learning_rate
with fluid.program_guard(self.program):
model = Pix2pix_model()
self.real_AB = fluid.layers.concat([input_A, input_B], 1)
self.fake_AB = fluid.layers.concat([input_A, fake_B], 1)
self.pred_real = model.network_D(
self.real_AB, "discriminator", cfg=cfg)
self.pred_fake = model.network_D(
self.fake_AB, "discriminator", cfg=cfg)
if cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_real,
shape=self.pred_real.shape,
value=1,
dtype='float32')
self.d_loss_real = fluid.layers.reduce_mean(
fluid.layers.square(
fluid.layers.elementwise_sub(
x=self.pred_real, y=ones)))
self.d_loss_fake = fluid.layers.reduce_mean(
fluid.layers.square(x=self.pred_fake))
elif cfg.gan_mode == "vanilla":
pred_shape = self.pred_real.shape
self.pred_real = fluid.layers.reshape(
self.pred_real,
[-1, pred_shape[1] * pred_shape[2] * pred_shape[3]],
inplace=True)
self.pred_fake = fluid.layers.reshape(
self.pred_fake,
[-1, pred_shape[1] * pred_shape[2] * pred_shape[3]],
inplace=True)
zeros = fluid.layers.fill_constant_batch_size_like(
input=self.pred_fake,
shape=self.pred_fake.shape,
value=0,
dtype='float32')
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_real,
shape=self.pred_real.shape,
value=1,
dtype='float32')
self.d_loss_real = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.pred_real, label=ones))
self.d_loss_fake = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.pred_fake, label=zeros))
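# Total discriminator loss: average of the real and fake terms (the 0.5 factor halves D's gradient relative to G).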
self.d_loss = 0.5 * (self.d_loss_real + self.d_loss_fake)
vars = []
for var in self.program.list_vars():
if fluid.io.is_parameter(var) and var.name.startswith(
"discriminator"):
vars.append(var.name)
self.param = vars
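# The discriminator uses the same piecewise schedule as the generator: constant for ~100 epochs, then linear decay.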
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=[99 * step_per_epoch] +
[x * step_per_epoch for x in range(100, cfg.epoch - 1)],
values=[lr] + [
lr * (1.0 - (x - 99.0) / 101.0)
for x in range(100, cfg.epoch)
]),
beta1=0.5,
beta2=0.999,
name="net_D")
optimizer.minimize(self.d_loss, parameter_list=vars)
class Pix2pix(object):
def add_special_args(self, parser):
parser.add_argument(
'--net_G',
type=str,
default="unet_256",
help="Choose the Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]"
)
parser.add_argument(
'--net_D',
type=str,
default="basic",
help="Choose the Pix2pix discriminator's network, choose in [basic|nlayers|pixel]"
)
parser.add_argument(
'--d_nlayers',
type=int,
default=3,
help="only used when Pix2pix discriminator is nlayers")
return parser
def __init__(self,
cfg=None,
train_reader=None,
test_reader=None,
batch_num=1):
self.cfg = cfg
self.train_reader = train_reader
self.test_reader = test_reader
self.batch_num = batch_num
def build_model(self):
data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size]
input_A = fluid.layers.data(
name='input_A', shape=data_shape, dtype='float32')
input_B = fluid.layers.data(
name='input_B', shape=data_shape, dtype='float32')
input_fake = fluid.layers.data(
name='input_fake', shape=data_shape, dtype='float32')
gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num)
dis_trainer = DTrainer(input_A, input_B, input_fake, self.cfg,
self.batch_num)
# prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if self.cfg.init_model:
utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, dis_trainer, "net_D")
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
gen_trainer_program = fluid.CompiledProgram(
gen_trainer.program).with_data_parallel(
loss_name=gen_trainer.g_loss.name,
build_strategy=build_strategy)
dis_trainer_program = fluid.CompiledProgram(
dis_trainer.program).with_data_parallel(
loss_name=dis_trainer.d_loss.name,
build_strategy=build_strategy)
t_time = 0
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for i in range(self.batch_num):
data_A, data_B = next(self.train_reader())
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_A.set(data_A, place)
tensor_B.set(data_B, place)
s_time = time.time()
# optimize the generator network
g_loss_gan, g_loss_l1, fake_B_tmp = exe.run(
gen_trainer_program,
fetch_list=[
gen_trainer.g_loss_gan, gen_trainer.g_loss_L1,
gen_trainer.fake_B
],
feed={"input_A": tensor_A,
"input_B": tensor_B})
# optimize the discriminator network
d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
fetch_list=[
dis_trainer.d_loss_real,
dis_trainer.d_loss_fake
],
feed={
"input_A": tensor_A,
"input_B": tensor_B,
"input_fake": fake_B_tmp
})
batch_time = time.time() - s_time
t_time += batch_time
if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
g_loss_gan: {}; g_loss_l1: {}; \n\
d_loss_real: {}; d_loss_fake: {}; \n\
Batch_time_cost: {:.2f}"
.format(epoch_id, batch_id, g_loss_gan[0], g_loss_l1[
0], d_loss_real[0], d_loss_fake[0], batch_time))
sys.stdout.flush()
batch_id += 1
if self.cfg.run_test:
test_program = gen_trainer.infer_program
utility.save_test_image(epoch_id, self.cfg, exe, place,
test_program, gen_trainer,
self.test_reader)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
"net_G")
utility.checkpoints(epoch_id, self.cfg, exe, dis_trainer,
"net_D")
...@@ -71,7 +71,9 @@ def base_parse_args(parser):
add_arg('model_net', str, "cgan", "The model used.")
add_arg('dataset', str, "mnist", "The dataset used.")
add_arg('data_dir', str, "./data", "The dataset root directory")
add_arg('data_list', str, "data/cityscapes/pix2pix_train_list", "The data list file name")
add_arg('train_list', str, "data/cityscapes/pix2pix_train_list", "The train list file name")
add_arg('test_list', str, "data/cityscapes/pix2pix_test_list10", "The test list file name")
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('epoch', int, 200, "The number of epochs to train.")
add_arg('g_base_dims', int, 64, "Base channels in CycleGAN generator")
...@@ -85,15 +87,16 @@ def base_parse_args(parser):
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('profile', bool, False, "Whether to profile.")
add_arg('dropout', bool, False, "Whether to use dropout.")
add_arg('use_dropout', bool, False, "Whether to use dropout")
add_arg('drop_last', bool, False,
"Whether to drop the last images that cannot form a batch")
add_arg('shuffle', bool, True, "Whether to shuffle data")
add_arg('output', str, "./output",
"The directory where the model and test results are saved.")
add_arg('init_model', str, None, "The init model file or directory.")
add_arg('gan_mode', str, "vanilla", "The GAN loss type, choose in [vanilla|lsgan]")
add_arg('norm_type', str, "batch_norm", "Which normalization to use")
add_arg('learning_rate', float, 0.0002, "The initial learning rate")
add_arg('lambda_L1', float, 100.0, "The weight of the L1 loss in the Pix2pix generator loss")
add_arg('num_generator_time', int, 1,
"the generator run times in training each epoch")
add_arg('print_freq', int, 10, "the frequency of print loss")
...
...@@ -66,45 +66,75 @@ def init_checkpoints(cfg, exe, trainer, name):
sys.stdout.flush()
def save_test_image(epoch,
                    cfg,
                    exe,
                    place,
                    test_program,
                    g_trainer,
                    A_test_reader,
                    B_test_reader=None):
    out_path = cfg.output + '/test'
    if not os.path.exists(out_path):
        os.makedirs(out_path)
    if B_test_reader is None:
        # Paired reader (Pix2pix): each batch carries (input, target, name).
        for data in zip(A_test_reader()):
            data_A, data_B, name = data[0]
            name = name[0]
            tensor_A = fluid.LoDTensor()
            tensor_B = fluid.LoDTensor()
            tensor_A.set(data_A, place)
            tensor_B.set(data_B, place)
            fake_B_temp = exe.run(
                test_program,
                fetch_list=[g_trainer.fake_B],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})
            fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
            input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
            input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])

            imsave(out_path + "/fakeB_" + str(epoch) + "_" + name, (
                (fake_B_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/inputA_" + str(epoch) + "_" + name, (
                (input_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/inputB_" + str(epoch) + "_" + name, (
                (input_B_temp + 1) * 127.5).astype(np.uint8))
    else:
        # Two unpaired readers (CycleGAN): save fake, cycle and input images for both domains.
        for data_A, data_B in zip(A_test_reader(), B_test_reader()):
            A_name = data_A[0][1]
            B_name = data_B[0][1]
            tensor_A = fluid.LoDTensor()
            tensor_B = fluid.LoDTensor()
            tensor_A.set(data_A[0][0], place)
            tensor_B.set(data_B[0][0], place)
            fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
                test_program,
                fetch_list=[
                    g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A,
                    g_trainer.cyc_B
                ],
                feed={"input_A": tensor_A,
                      "input_B": tensor_B})
            fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0])
            fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
            cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0])
            cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0])
            input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
            input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])

            imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
                (fake_B_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
                (fake_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
                (cyc_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
                (cyc_B_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
                (input_A_temp + 1) * 127.5).astype(np.uint8))
            imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
                (input_B_temp + 1) * 127.5).astype(np.uint8))
class ImagePool(object):
...