From 2a9bd0d89ea14c4a05bf1b2ec28dea3618631def Mon Sep 17 00:00:00 2001 From: zhumanyu <259571082@qq.com> Date: Wed, 29 May 2019 10:21:22 +0800 Subject: [PATCH] add Pix2pix to gan library(#2296) * add pix2pix to gan library --- PaddleCV/gan/data_reader.py | 141 ++++++ PaddleCV/gan/infer.py | 5 + PaddleCV/gan/network/Pix2pix_network.py | 565 ++++++++++++++++++++++++ PaddleCV/gan/scripts/infer_pix2pix.sh | 1 + PaddleCV/gan/scripts/run_pix2pix.sh | 1 + PaddleCV/gan/train.py | 5 + PaddleCV/gan/trainer/Pix2pix.py | 285 ++++++++++++ PaddleCV/gan/util/config.py | 9 +- PaddleCV/gan/util/utility.py | 100 +++-- 9 files changed, 1074 insertions(+), 38 deletions(-) create mode 100644 PaddleCV/gan/network/Pix2pix_network.py create mode 100644 PaddleCV/gan/scripts/infer_pix2pix.sh create mode 100644 PaddleCV/gan/scripts/run_pix2pix.sh create mode 100644 PaddleCV/gan/trainer/Pix2pix.py diff --git a/PaddleCV/gan/data_reader.py b/PaddleCV/gan/data_reader.py index 98d507b9..46fc91a4 100644 --- a/PaddleCV/gan/data_reader.py +++ b/PaddleCV/gan/data_reader.py @@ -22,6 +22,7 @@ import argparse import struct import os import paddle +import random def RandomCrop(img, crop_w, crop_h): @@ -45,6 +46,18 @@ def RandomHorizonFlip(img): return img +def get_preprocess_param(load_size, crop_size): + x = np.random.randint(0, np.maximum(0, load_size - crop_size)) + y = np.random.randint(0, np.maximum(0, load_size - crop_size)) + flip = np.random.rand() > 0.5 + return { + "crop_pos": (x, y), + "flip": flip, + "load_size": load_size, + "crop_size": crop_size + } + + class reader_creator(object): ''' read and preprocess dataset''' @@ -122,6 +135,108 @@ class reader_creator(object): return reader +class pair_reader_creator(reader_creator): + ''' read and preprocess dataset''' + + def __init__(self, image_dir, list_filename, batch_size=1, drop_last=False): + super(pair_reader_creator, self).__init__( + image_dir, list_filename, batch_size=1, drop_last=drop_last) + + def get_train_reader(self, args, shuffle=False, return_name=False): + print(self.image_dir, self.list_filename) + + def reader(): + batch_out_1 = [] + batch_out_2 = [] + while True: + if shuffle: + np.random.shuffle(self.lines) + for line in self.lines: + files = line.strip('\n\r\t ').split('\t') + img1 = Image.open(os.path.join(self.image_dir, files[ + 0])).convert('RGB') + img2 = Image.open(os.path.join(self.image_dir, files[ + 1])).convert('RGB') + param = get_preprocess_param(args.load_size, args.crop_size) + img1 = img1.resize((args.load_size, args.load_size), + Image.BICUBIC) + img2 = img2.resize((args.load_size, args.load_size), + Image.BICUBIC) + if args.crop_type == 'Centor': + img1 = CentorCrop(img1, args.crop_size, args.crop_size) + img2 = CentorCrop(img2, args.crop_size, args.crop_size) + elif args.crop_type == 'Random': + x = param['crop_pos'][0] + y = param['crop_pos'][1] + img1 = img1.crop( + (x, y, x + args.crop_size, y + args.crop_size)) + img2 = img2.crop( + (x, y, x + args.crop_size, y + args.crop_size)) + img1 = ( + np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5 + img1 = img1.transpose([2, 0, 1]) + img2 = ( + np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5 + img2 = img2.transpose([2, 0, 1]) + + batch_out_1.append(img1) + batch_out_2.append(img2) + if len(batch_out_1) == self.batch_size: + yield batch_out_1, batch_out_2 + batch_out_1 = [] + batch_out_2 = [] + if self.drop_last == False and len(batch_out_1) != 0: + yield batch_out_1, batch_out_2 + + return reader + + def get_test_reader(self, args, shuffle=False, return_name=False): + print(self.image_dir, self.list_filename) + + def reader(): + batch_out_1 = [] + batch_out_2 = [] + batch_out_3 = [] + for line in self.lines: + files = line.strip('\n\r\t ').split('\t') + img1 = Image.open(os.path.join(self.image_dir, files[ + 0])).convert('RGB') + img2 = Image.open(os.path.join(self.image_dir, files[ + 1])).convert('RGB') + img1 = img1.resize((args.crop_size, args.crop_size), + Image.BICUBIC) + img2 = img2.resize((args.crop_size, args.crop_size), + Image.BICUBIC) + img1 = (np.array(img1).astype('float32') / 255.0 - 0.5) / 0.5 + img1 = img1.transpose([2, 0, 1]) + img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5 + img2 = img2.transpose([2, 0, 1]) + if return_name: + batch_out_1.append(img1) + batch_out_2.append(img2) + batch_out_3.append(os.path.basename(files[0])) + else: + batch_out_1.append(img1) + batch_out_2.append(img2) + if len(batch_out_1) == self.batch_size: + if return_name: + yield batch_out_1, batch_out_2, batch_out_3 + batch_out_1 = [] + batch_out_2 = [] + batch_out_3 = [] + else: + yield batch_out_1, batch_out_2 + batch_out_1 = [] + batch_out_2 = [] + if len(batch_out_1) != 0: + if return_name: + yield batch_out_1, batch_out_2, batch_out_3 + else: + yield batch_out_1, batch_out_2 + + return reader + + def mnist_reader_creator(image_filename, label_filename, buffer_size): def reader(): with gzip.GzipFile(image_filename, 'rb') as image_file: @@ -231,6 +346,32 @@ class data_reader(object): return a_reader, b_reader, a_reader_test, b_reader_test, batch_num + elif self.cfg.model_net == 'Pix2pix': + dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset) + train_list = os.path.join(dataset_dir, 'train.txt') + if self.cfg.train_list is not None: + train_list = self.cfg.train_list + train_reader = pair_reader_creator( + image_dir=dataset_dir, + list_filename=train_list, + batch_size=self.cfg.batch_size, + drop_last=self.cfg.drop_last) + reader_test = None + if self.cfg.run_test: + test_list = os.path.join(dataset_dir, "test.txt") + if self.cfg.test_list is not None: + test_list = self.cfg.test_list + test_reader = pair_reader_creator( + image_dir=dataset_dir, + list_filename=test_list, + batch_size=1, + drop_last=self.cfg.drop_last) + reader_test = test_reader.get_test_reader( + self.cfg, shuffle=False, return_name=True) + batch_num = train_reader.len() + reader = train_reader.get_train_reader( + self.cfg, shuffle=self.shuffle) + return reader, reader_test, batch_num else: dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset) train_list = os.path.join(dataset_dir, 'train.txt') diff --git a/PaddleCV/gan/infer.py b/PaddleCV/gan/infer.py index ab7bc6c6..3d811936 100644 --- a/PaddleCV/gan/infer.py +++ b/PaddleCV/gan/infer.py @@ -57,6 +57,11 @@ def infer(args): fake = network_G(input, name="GB", cfg=args) else: raise "Input with style [%s] is not supported." % args.input_style + elif args.model_net == 'Pix2pix': + from network.Pix2pix_network import Pix2pix_model + model = Pix2pix_model() + fake = model.network_G(input, "generator", cfg=args) + elif args.model_net == 'cgan': pass else: diff --git a/PaddleCV/gan/network/Pix2pix_network.py b/PaddleCV/gan/network/Pix2pix_network.py new file mode 100644 index 00000000..6ba949d4 --- /dev/null +++ b/PaddleCV/gan/network/Pix2pix_network.py @@ -0,0 +1,565 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .base_network import conv2d, deconv2d, norm_layer +import paddle.fluid as fluid + + +class Pix2pix_model(object): + def __init__(self): + pass + + def network_G(self, input, name, cfg): + if cfg.net_G == 'resnet_9block': + net = build_generator_resnet_blocks( + input, + name=name + "_resnet9block", + n_gen_res=9, + g_base_dims=cfg.g_base_dims, + use_dropout=cfg.dropout, + norm_type=cfg.norm_type) + elif cfg.net_G == 'resnet_6block': + net = build_generator_resnet_blocks( + input, + name=name + "_resnet6block", + n_gen_res=6, + g_base_dims=cfg.g_base_dims, + use_dropout=cfg.dropout, + norm_type=cfg.norm_type) + elif cfg.net_G == 'unet_128': + net = build_generator_Unet( + input, + name=name + "_unet128", + num_downsample=7, + g_base_dims=cfg.g_base_dims, + use_dropout=cfg.dropout, + norm_type=cfg.norm_type) + elif cfg.net_G == 'unet_256': + net = build_generator_Unet( + input, + name=name + "_unet256", + num_downsample=8, + g_base_dims=cfg.g_base_dims, + use_dropout=cfg.dropout, + norm_type=cfg.norm_type) + else: + raise NotImplementedError( + 'network G: [%s] is wrong format, please check it' % cfg.net_G) + return net + + def network_D(self, input, name, cfg): + if cfg.net_D == 'basic': + net = build_discriminator_Nlayers( + input, + name=name + '_basic', + d_nlayers=3, + d_base_dims=cfg.d_base_dims, + norm_type=cfg.norm_type) + elif cfg.net_D == 'nlayers': + net = build_discriminator_Nlayers( + input, + name=name + '_nlayers', + d_nlayers=cfg.d_nlayers, + d_base_dims=cfg.d_base_dims, + norm_type=cfg.norm_type) + elif cfg.net_D == 'pixel': + net = build_discriminator_Pixel( + input, + name=name + '_pixel', + d_base_dims=cfg.d_base_dims, + norm_type=cfg.norm_type) + else: + raise NotImplementedError( + 'network D: [%s] is wrong format, please check it' % cfg.net_D) + return net + + +def build_resnet_block(inputres, + dim, + name="resnet", + use_bias=False, + use_dropout=False, + norm_type='batch_norm'): + out_res = fluid.layers.pad2d(inputres, [1, 1, 1, 1], mode="reflect") + out_res = conv2d( + out_res, + dim, + 3, + 1, + 0.02, + name=name + "_c1", + norm=norm_type, + activation_fn='relu', + use_bias=use_bias) + + if use_dropout: + out_res = fluid.layers.dropout(out_res, dropout_prob=0.5) + + out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect") + out_res = conv2d( + out_res, + dim, + 3, + 1, + 0.02, + name=name + "_c2", + norm=norm_type, + use_bias=use_bias) + return out_res + inputres + + +def build_generator_resnet_blocks(inputgen, + name="generator", + n_gen_res=9, + g_base_dims=64, + use_dropout=False, + norm_type='batch_norm'): + ''' generator use resnet block''' + '''The shape of input should be equal to the shape of output.''' + use_bias = norm_type == 'instance_norm' + pad_input = fluid.layers.pad2d(inputgen, [3, 3, 3, 3], mode="reflect") + o_c1 = conv2d( + pad_input, + g_base_dims, + 7, + 1, + 0.02, + name=name + "_c1", + norm=norm_type, + activation_fn='relu') + o_c2 = conv2d( + o_c1, + g_base_dims * 2, + 3, + 2, + 0.02, + 1, + name=name + "_c2", + norm=norm_type, + activation_fn='relu') + res_input = conv2d( + o_c2, + g_base_dims * 4, + 3, + 2, + 0.02, + 1, + name=name + "_c3", + norm=norm_type, + activation_fn='relu') + for i in xrange(n_gen_res): + conv_name = name + "_r{}".format(i + 1) + res_output = build_resnet_block( + res_input, + g_base_dims * 4, + name=conv_name, + use_bias=use_bias, + use_dropout=use_dropout) + res_input = res_output + + o_c4 = deconv2d( + res_output, + g_base_dims * 2, + 3, + 2, + 0.02, [1, 1], [0, 1, 0, 1], + name=name + "_c4", + norm=norm_type, + activation_fn='relu') + o_c5 = deconv2d( + o_c4, + g_base_dims, + 3, + 2, + 0.02, [1, 1], [0, 1, 0, 1], + name=name + "_c5", + norm=norm_type, + activation_fn='relu') + o_p2 = fluid.layers.pad2d(o_c5, [3, 3, 3, 3], mode="reflect") + o_c6 = conv2d( + o_p2, + 3, + 7, + 1, + 0.02, + name=name + "_c6", + activation_fn='tanh', + use_bias=True) + + return o_c6 + + +def Unet_block(inputunet, + i, + outer_dim, + inner_dim, + num_downsample, + innermost=False, + outermost=False, + norm_type='batch_norm', + use_bias=False, + use_dropout=False, + name=None): + if outermost == True: + downconv = conv2d( + inputunet, + inner_dim, + 4, + 2, + 0.02, + 1, + name=name + '_outermost_dc1', + use_bias=True) + i += 1 + mid_block = Unet_block( + downconv, + i, + inner_dim, + inner_dim * 2, + num_downsample, + norm_type=norm_type, + use_bias=use_bias, + use_dropout=use_dropout, + name=name) + uprelu = fluid.layers.relu(mid_block, name=name + '_outermost_relu') + updeconv = deconv2d( + uprelu, + outer_dim, + 4, + 2, + 0.02, + 1, + name=name + '_outermost_uc1', + activation_fn='tanh', + use_bias=use_bias) + return updeconv + elif innermost == True: + downrelu = fluid.layers.leaky_relu( + inputunet, 0.2, name=name + '_innermost_leaky_relu') + upconv = conv2d( + downrelu, + inner_dim, + 4, + 2, + 0.02, + 1, + name=name + '_innermost_dc1', + activation_fn='relu', + use_bias=use_bias) + updeconv = deconv2d( + upconv, + outer_dim, + 4, + 2, + 0.02, + 1, + name=name + '_innermost_uc1', + norm=norm_type, + use_bias=use_bias) + return fluid.layers.concat([inputunet, updeconv], 1) + else: + downrelu = fluid.layers.leaky_relu( + inputunet, 0.2, name=name + '_leaky_relu') + downnorm = conv2d( + downrelu, + inner_dim, + 4, + 2, + 0.02, + 1, + name=name + 'dc1', + norm=norm_type, + use_bias=use_bias) + i += 1 + if i < 4: + mid_block = Unet_block( + downnorm, + i, + inner_dim, + inner_dim * 2, + num_downsample, + norm_type=norm_type, + use_bias=use_bias, + name=name + '_mid{}'.format(i)) + elif i < num_downsample - 1: + mid_block = Unet_block( + downnorm, + i, + inner_dim, + inner_dim, + num_downsample, + norm_type=norm_type, + use_bias=use_bias, + use_dropout=use_dropout, + name=name + '_mid{}'.format(i)) + else: + mid_block = Unet_block( + downnorm, + i, + inner_dim, + inner_dim, + num_downsample, + innermost=True, + norm_type=norm_type, + use_bias=use_bias, + name=name + '_innermost') + uprelu = fluid.layers.relu(mid_block, name=name + '_relu') + updeconv = deconv2d( + uprelu, + outer_dim, + 4, + 2, + 0.02, + 1, + name=name + '_uc1', + norm=norm_type, + use_bias=use_bias) + + if use_dropout: + upnorm = fluid.layers.dropout(upnorm, dropout_prob=0.5) + return fluid.layers.concat([inputunet, updeconv], 1) + + +def UnetSkipConnectionBlock(input, + i, + num_downs, + outer_nc, + inner_nc, + outermost=False, + innermost=False, + norm='batch_norm', + use_dropout=False, + name=""): + use_bias = norm == "instance" + if outermost: + downconv = conv2d( + input, + inner_nc, + 4, + 2, + padding=1, + use_bias=use_bias, + name=name + '_down_conv') + i += 1 + ngf = inner_nc + sub_res = UnetSkipConnectionBlock( + downconv, + i, + num_downs, + outer_nc=ngf, + inner_nc=ngf * 2, + norm=norm, + name=name + '_u%d' % i) + uprelu = fluid.layers.relu(sub_res) + upconv = deconv2d( + uprelu, + outer_nc, + 4, + 2, + padding=1, + activation_fn='tanh', + name=name + '_up_conv') + return upconv + elif innermost: + downrelu = fluid.layers.leaky_relu(input, 0.2) + downconv = conv2d( + downrelu, + inner_nc, + 4, + 2, + padding=1, + use_bias=use_bias, + name=name + '_down_conv') + uprelu = fluid.layers.relu(downconv) + upconv = deconv2d( + uprelu, + outer_nc, + 4, + 2, + padding=1, + use_bias=use_bias, + norm=norm, + name=name + '_up_conv') + return fluid.layers.concat([input, upconv], 1) + else: + downrelu = fluid.layers.leaky_relu(input, 0.2) + downconv = conv2d( + downrelu, + inner_nc, + 4, + 2, + padding=1, + use_bias=use_bias, + norm=norm, + name=name + '_down_conv') + i += 1 + ngf = inner_nc + if i < 4: + sub_res = UnetSkipConnectionBlock( + downconv, + i, + num_downs, + outer_nc=ngf, + inner_nc=ngf * 2, + norm=norm, + name=name + '_u%d' % i) + elif i < num_downs - 1: + sub_res = UnetSkipConnectionBlock( + downconv, + i, + num_downs, + outer_nc=ngf, + inner_nc=ngf, + norm=norm, + name=name + '_u%d' % i) + + else: + sub_res = UnetSkipConnectionBlock( + downconv, + i, + num_downs, + outer_nc=ngf, + inner_nc=ngf, + innermost=True, + norm=norm, + name=name + '_u%d' % i) + + uprelu = fluid.layers.relu(sub_res) + upconv = deconv2d( + uprelu, + outer_nc, + 4, + 2, + padding=1, + use_bias=use_bias, + norm=norm, + name=name + '_up_conv') + out = upconv + if use_dropout: + out = fluid.layers.dropout(out, 0.5) + return fluid.layers.concat([input, out], 1) + + +def build_generator_Unet(input, + name="", + num_downsample=8, + g_base_dims=64, + use_dropout=False, + norm_type='batch_norm'): + ''' generator use Unet''' + i = 0 + output = UnetSkipConnectionBlock( + input, + i, + num_downsample, + 3, + g_base_dims, + outermost=True, + norm=norm_type, + name=name + '_u%d' % i) + + return output + + +def build_discriminator_Nlayers(inputdisc, + name="discriminator", + d_nlayers=3, + d_base_dims=64, + norm_type='batch_norm'): + use_bias = norm_type != 'batch_norm' + dis_input = conv2d( + inputdisc, + d_base_dims, + 4, + 2, + 0.02, + 1, + name=name + "_c1", + activation_fn='leaky_relu', + relufactor=0.2, + use_bias=True) + d_dims = d_base_dims + for i in xrange(d_nlayers - 1): + conv_name = name + "_c{}".format(i + 2) + d_dims *= 2 + dis_output = conv2d( + dis_input, + d_dims, + 4, + 2, + 0.02, + 1, + name=conv_name, + norm=norm_type, + activation_fn='leaky_relu', + relufactor=0.2, + use_bias=use_bias) + dis_input = dis_output + last_dims = min(2**d_nlayers, 8) + o_c4 = conv2d( + dis_output, + d_base_dims * last_dims, + 4, + 1, + 0.02, + 1, + name + "_c{}".format(d_nlayers + 1), + norm=norm_type, + activation_fn='leaky_relu', + relufactor=0.2, + use_bias=use_bias) + o_c5 = conv2d( + o_c4, + 1, + 4, + 1, + 0.02, + 1, + name + "_c{}".format(d_nlayers + 2), + use_bias=True) + return o_c5 + + +def build_discriminator_Pixel(inputdisc, + name="discriminator", + d_base_dims=64, + norm_type='batch_norm'): + use_bias = norm_type != 'instance_norm' + o_c1 = conv2d( + inputdisc, + d_base_dims, + 1, + 1, + 0.02, + name=name + '_c1', + activation_fn='leaky_relu', + relufactor=0.2, + use_bias=True) + o_c2 = conv2d( + o_c1, + d_base_dims * 2, + 1, + 1, + 0.02, + name=name + '_c2', + norm=norm_type, + activation_fn='leaky_relu', + relufactor=0.2, + use_bias=use_bias) + o_c3 = conv2d(o_c2, 1, 1, 1, 0.02, name=name + '_c3', use_bias=use_bias) + return o_c3 diff --git a/PaddleCV/gan/scripts/infer_pix2pix.sh b/PaddleCV/gan/scripts/infer_pix2pix.sh new file mode 100644 index 00000000..2e5d484b --- /dev/null +++ b/PaddleCV/gan/scripts/infer_pix2pix.sh @@ -0,0 +1 @@ +python infer.py --init_model output/chechpoints/15/ --input data/cityscapes/test/B/100.jpg --model_net Pix2pix --net_G unet_256 diff --git a/PaddleCV/gan/scripts/run_pix2pix.sh b/PaddleCV/gan/scripts/run_pix2pix.sh new file mode 100644 index 00000000..4754888f --- /dev/null +++ b/PaddleCV/gan/scripts/run_pix2pix.sh @@ -0,0 +1 @@ +python train.py --model_net Pix2pix --dataset cityscapes --train_list data/cityscapes/pix2pix_train_list --test_list data/cityscapes/pix2pix_test_list10 --crop_type Random --dropout True --gan_mode vanilla --batch_size 1 > log_out 2>log_err diff --git a/PaddleCV/gan/train.py b/PaddleCV/gan/train.py index ba6f3a0a..eb9857c9 100644 --- a/PaddleCV/gan/train.py +++ b/PaddleCV/gan/train.py @@ -31,6 +31,8 @@ def train(cfg): if cfg.model_net == 'CycleGAN': a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data( ) + elif cfg.model_net == 'Pix2pix': + train_reader, test_reader, batch_num = reader.make_data() else: if cfg.dataset == 'mnist': train_reader = reader.make_data() @@ -51,6 +53,9 @@ def train(cfg): from trainer.CycleGAN import CycleGAN model = CycleGAN(cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num) + elif cfg.model_net == 'Pix2pix': + from trainer.Pix2pix import Pix2pix + model = Pix2pix(cfg, train_reader, test_reader, batch_num) else: pass diff --git a/PaddleCV/gan/trainer/Pix2pix.py b/PaddleCV/gan/trainer/Pix2pix.py new file mode 100644 index 00000000..e965ab4d --- /dev/null +++ b/PaddleCV/gan/trainer/Pix2pix.py @@ -0,0 +1,285 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from network.Pix2pix_network import Pix2pix_model +from util import utility +import paddle.fluid as fluid +import sys +import time + + +class GTrainer(): + def __init__(self, input_A, input_B, cfg, step_per_epoch): + self.program = fluid.default_main_program().clone() + with fluid.program_guard(self.program): + model = Pix2pix_model() + self.fake_B = model.network_G(input_A, "generator", cfg=cfg) + self.fake_B.persistable = True + self.infer_program = self.program.clone() + AB = fluid.layers.concat([input_A, self.fake_B], 1) + self.pred = model.network_D(AB, "discriminator", cfg) + if cfg.gan_mode == "lsgan": + ones = fluid.layers.fill_constant_batch_size_like( + input=self.pred, + shape=self.pred.shape, + value=1, + dtype='float32') + self.g_loss_gan = fluid.layers.reduce_mean( + fluid.layers.square( + fluid.layers.elementwise_sub( + x=self.pred, y=ones))) + elif cfg.gan_mode == "vanilla": + pred_shape = self.pred.shape + self.pred = fluid.layers.reshape( + self.pred, + [-1, pred_shape[1] * pred_shape[2] * pred_shape[3]], + inplace=True) + ones = fluid.layers.fill_constant_batch_size_like( + input=self.pred, + shape=self.pred.shape, + value=1, + dtype='float32') + self.g_loss_gan = fluid.layers.mean( + fluid.layers.sigmoid_cross_entropy_with_logits( + x=self.pred, label=ones)) + + self.g_loss_L1 = fluid.layers.reduce_mean( + fluid.layers.abs( + fluid.layers.elementwise_sub( + x=input_B, y=self.fake_B))) * cfg.lambda_L1 + self.g_loss = fluid.layers.elementwise_add(self.g_loss_L1, + self.g_loss_gan) + lr = cfg.learning_rate + vars = [] + for var in self.program.list_vars(): + if fluid.io.is_parameter(var) and var.name.startswith( + "generator"): + vars.append(var.name) + self.param = vars + optimizer = fluid.optimizer.Adam( + learning_rate=fluid.layers.piecewise_decay( + boundaries=[99 * step_per_epoch] + + [x * step_per_epoch for x in range(100, cfg.epoch - 1)], + values=[lr] + [ + lr * (1.0 - (x - 99.0) / 101.0) + for x in range(100, cfg.epoch) + ]), + beta1=0.5, + beta2=0.999, + name="net_G") + optimizer.minimize(self.g_loss, parameter_list=vars) + + +class DTrainer(): + def __init__(self, input_A, input_B, fake_B, cfg, step_per_epoch): + self.program = fluid.default_main_program().clone() + lr = cfg.learning_rate + with fluid.program_guard(self.program): + model = Pix2pix_model() + self.real_AB = fluid.layers.concat([input_A, input_B], 1) + self.fake_AB = fluid.layers.concat([input_A, fake_B], 1) + self.pred_real = model.network_D( + self.real_AB, "discriminator", cfg=cfg) + self.pred_fake = model.network_D( + self.fake_AB, "discriminator", cfg=cfg) + if cfg.gan_mode == "lsgan": + ones = fluid.layers.fill_constant_batch_size_like( + input=self.pred_real, + shape=self.pred_real.shape, + value=1, + dtype='float32') + self.d_loss_real = fluid.layers.reduce_mean( + fluid.layers.square( + fluid.layers.elementwise_sub( + x=self.pred_real, y=ones))) + self.d_loss_fake = fluid.layers.reduce_mean( + fluid.layers.square(x=self.pred_fake)) + elif cfg.gan_mode == "vanilla": + pred_shape = self.pred_real.shape + self.pred_real = fluid.layers.reshape( + self.pred_real, + [-1, pred_shape[1] * pred_shape[2] * pred_shape[3]], + inplace=True) + self.pred_fake = fluid.layers.reshape( + self.pred_fake, + [-1, pred_shape[1] * pred_shape[2] * pred_shape[3]], + inplace=True) + zeros = fluid.layers.fill_constant_batch_size_like( + input=self.pred_fake, + shape=self.pred_fake.shape, + value=0, + dtype='float32') + ones = fluid.layers.fill_constant_batch_size_like( + input=self.pred_real, + shape=self.pred_real.shape, + value=1, + dtype='float32') + self.d_loss_real = fluid.layers.mean( + fluid.layers.sigmoid_cross_entropy_with_logits( + x=self.pred_real, label=ones)) + self.d_loss_fake = fluid.layers.mean( + fluid.layers.sigmoid_cross_entropy_with_logits( + x=self.pred_fake, label=zeros)) + self.d_loss = 0.5 * (self.d_loss_real + self.d_loss_fake) + vars = [] + for var in self.program.list_vars(): + if fluid.io.is_parameter(var) and var.name.startswith( + "discriminator"): + vars.append(var.name) + + self.param = vars + optimizer = fluid.optimizer.Adam( + learning_rate=fluid.layers.piecewise_decay( + boundaries=[99 * step_per_epoch] + + [x * step_per_epoch for x in range(100, cfg.epoch - 1)], + values=[lr] + [ + lr * (1.0 - (x - 99.0) / 101.0) + for x in range(100, cfg.epoch) + ]), + beta1=0.5, + beta2=0.999, + name="net_D") + + optimizer.minimize(self.d_loss, parameter_list=vars) + + +class Pix2pix(object): + def add_special_args(self, parser): + parser.add_argument( + '--net_G', + type=str, + default="unet_256", + help="Choose the Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]" + ) + parser.add_argument( + '--net_D', + type=str, + default="basic", + help="Choose the Pix2pix discriminator's network, choose in [basic|nlayers|pixel]" + ) + parser.add_argument( + '--d_nlayers', + type=int, + default=3, + help="only used when Pix2pix discriminator is nlayers") + + return parser + + def __init__(self, + cfg=None, + train_reader=None, + test_reader=None, + batch_num=1): + self.cfg = cfg + self.train_reader = train_reader + self.test_reader = test_reader + self.batch_num = batch_num + + def build_model(self): + data_shape = [-1, 3, self.cfg.crop_size, self.cfg.crop_size] + + input_A = fluid.layers.data( + name='input_A', shape=data_shape, dtype='float32') + input_B = fluid.layers.data( + name='input_B', shape=data_shape, dtype='float32') + input_fake = fluid.layers.data( + name='input_fake', shape=data_shape, dtype='float32') + + gen_trainer = GTrainer(input_A, input_B, self.cfg, self.batch_num) + dis_trainer = DTrainer(input_A, input_B, input_fake, self.cfg, + self.batch_num) + + # prepare environment + place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + if self.cfg.init_model: + utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G") + utility.init_checkpoints(self.cfg, exe, dis_trainer, "net_D") + + ### memory optim + build_strategy = fluid.BuildStrategy() + build_strategy.enable_inplace = False + build_strategy.memory_optimize = False + + gen_trainer_program = fluid.CompiledProgram( + gen_trainer.program).with_data_parallel( + loss_name=gen_trainer.g_loss.name, + build_strategy=build_strategy) + dis_trainer_program = fluid.CompiledProgram( + dis_trainer.program).with_data_parallel( + loss_name=dis_trainer.d_loss.name, + build_strategy=build_strategy) + + t_time = 0 + + for epoch_id in range(self.cfg.epoch): + batch_id = 0 + for i in range(self.batch_num): + data_A, data_B = next(self.train_reader()) + tensor_A = fluid.LoDTensor() + tensor_B = fluid.LoDTensor() + tensor_A.set(data_A, place) + tensor_B.set(data_B, place) + s_time = time.time() + # optimize the generator network + g_loss_gan, g_loss_l1, fake_B_tmp = exe.run( + gen_trainer_program, + fetch_list=[ + gen_trainer.g_loss_gan, gen_trainer.g_loss_L1, + gen_trainer.fake_B + ], + feed={"input_A": tensor_A, + "input_B": tensor_B}) + + # optimize the discriminator network + d_loss_real, d_loss_fake = exe.run(dis_trainer_program, + fetch_list=[ + dis_trainer.d_loss_real, + dis_trainer.d_loss_fake + ], + feed={ + "input_A": tensor_A, + "input_B": tensor_B, + "input_fake": fake_B_tmp + }) + + batch_time = time.time() - s_time + t_time += batch_time + if batch_id % self.cfg.print_freq == 0: + print("epoch{}: batch{}: \n\ + g_loss_gan: {}; g_loss_l1: {}; \n\ + d_loss_real: {}; d_loss_fake: {}; \n\ + Batch_time_cost: {:.2f}" + .format(epoch_id, batch_id, g_loss_gan[0], g_loss_l1[ + 0], d_loss_real[0], d_loss_fake[0], batch_time)) + + sys.stdout.flush() + batch_id += 1 + + if self.cfg.run_test: + test_program = gen_trainer.infer_program + utility.save_test_image(epoch_id, self.cfg, exe, place, + test_program, gen_trainer, + self.test_reader) + + if self.cfg.save_checkpoints: + utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer, + "net_G") + utility.checkpoints(epoch_id, self.cfg, exe, dis_trainer, + "net_D") diff --git a/PaddleCV/gan/util/config.py b/PaddleCV/gan/util/config.py index 46dc0fbb..8df6b9c6 100644 --- a/PaddleCV/gan/util/config.py +++ b/PaddleCV/gan/util/config.py @@ -71,7 +71,9 @@ def base_parse_args(parser): add_arg('model_net', str, "cgan", "The model used.") add_arg('dataset', str, "mnist", "The dataset used.") add_arg('data_dir', str, "./data", "The dataset root directory") - add_arg('data_list', str, None, "The dataset list file name") + add_arg('data_list', str, "data/cityscapes/pix2pix_train_list", "The data list file name") + add_arg('train_list', str, "data/cityscapes/pix2pix_train_list", "The train list file name") + add_arg('test_list', str, "data/cityscapes/pix2pix_test_list10", "The test list file name") add_arg('batch_size', int, 1, "Minibatch size.") add_arg('epoch', int, 200, "The number of epoch to be trained.") add_arg('g_base_dims', int, 64, "Base channels in CycleGAN generator") @@ -85,15 +87,16 @@ def base_parse_args(parser): add_arg('use_gpu', bool, True, "Whether to use GPU to train.") add_arg('profile', bool, False, "Whether to profile.") add_arg('dropout', bool, False, "Whether to use drouput.") - add_arg('use_dropout', bool, False, "Whether to use dropout") add_arg('drop_last', bool, False, "Whether to drop the last images that cannot form a batch") add_arg('shuffle', bool, True, "Whether to shuffle data") add_arg('output', str, "./output", "The directory the model and the test result to be saved to.") add_arg('init_model', str, None, "The init model file of directory.") + add_arg('gan_mode', str, "vanilla", "The init model file of directory.") add_arg('norm_type', str, "batch_norm", "Which normalization to used") - add_arg('learning_rate', int, 0.0002, "the initialize learning rate") + add_arg('learning_rate', float, 0.0002, "the initialize learning rate") + add_arg('lambda_L1', float, 100.0, "the initialize learning rate") add_arg('num_generator_time', int, 1, "the generator run times in training each epoch") add_arg('print_freq', int, 10, "the frequency of print loss") diff --git a/PaddleCV/gan/util/utility.py b/PaddleCV/gan/util/utility.py index 3c2af5bf..00d93d98 100644 --- a/PaddleCV/gan/util/utility.py +++ b/PaddleCV/gan/util/utility.py @@ -66,45 +66,75 @@ def init_checkpoints(cfg, exe, trainer, name): sys.stdout.flush() -def save_test_image(epoch, cfg, exe, place, test_program, g_trainer, - A_test_reader, B_test_reader): +def save_test_image(epoch, + cfg, + exe, + place, + test_program, + g_trainer, + A_test_reader, + B_test_reader=None): out_path = cfg.output + '/test' if not os.path.exists(out_path): os.makedirs(out_path) - for data_A, data_B in zip(A_test_reader(), B_test_reader()): - A_name = data_A[0][1] - B_name = data_B[0][1] - tensor_A = fluid.LoDTensor() - tensor_B = fluid.LoDTensor() - tensor_A.set(data_A[0][0], place) - tensor_B.set(data_B[0][0], place) - fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run( - test_program, - fetch_list=[ - g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A, - g_trainer.cyc_B - ], - feed={"input_A": tensor_A, - "input_B": tensor_B}) - fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0]) - fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0]) - cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0]) - cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0]) - input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0]) - input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0]) + if B_test_reader is None: + for data in zip(A_test_reader()): + data_A, data_B, name = data[0] + name = name[0] + tensor_A = fluid.LoDTensor() + tensor_B = fluid.LoDTensor() + tensor_A.set(data_A, place) + tensor_B.set(data_B, place) + fake_B_temp = exe.run( + test_program, + fetch_list=[g_trainer.fake_B], + feed={"input_A": tensor_A, + "input_B": tensor_B}) + fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0]) + input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0]) + input_B_temp = np.squeeze(data_A[0]).transpose([1, 2, 0]) - imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, ( - (fake_B_temp + 1) * 127.5).astype(np.uint8)) - imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, ( - (fake_A_temp + 1) * 127.5).astype(np.uint8)) - imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, ( - (cyc_A_temp + 1) * 127.5).astype(np.uint8)) - imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, ( - (cyc_B_temp + 1) * 127.5).astype(np.uint8)) - imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, ( - (input_A_temp + 1) * 127.5).astype(np.uint8)) - imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, ( - (input_B_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/fakeB_" + str(epoch) + "_" + name, ( + (fake_B_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/inputA_" + str(epoch) + "_" + name, ( + (input_A_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/inputB_" + str(epoch) + "_" + name, ( + (input_B_temp + 1) * 127.5).astype(np.uint8)) + else: + for data_A, data_B in zip(A_test_reader(), B_test_reader()): + A_name = data_A[0][1] + B_name = data_B[0][1] + tensor_A = fluid.LoDTensor() + tensor_B = fluid.LoDTensor() + tensor_A.set(data_A[0][0], place) + tensor_B.set(data_B[0][0], place) + fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run( + test_program, + fetch_list=[ + g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A, + g_trainer.cyc_B + ], + feed={"input_A": tensor_A, + "input_B": tensor_B}) + fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0]) + fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0]) + cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0]) + cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0]) + input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0]) + input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0]) + + imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, ( + (fake_B_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, ( + (fake_A_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, ( + (cyc_A_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, ( + (cyc_B_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, ( + (input_A_temp + 1) * 127.5).astype(np.uint8)) + imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, ( + (input_B_temp + 1) * 127.5).astype(np.uint8)) class ImagePool(object): -- GitLab