diff --git a/PaddleCV/gan/data_reader.py b/PaddleCV/gan/data_reader.py
index 6515d48d09494efd1c258e1c02916301d9b5d7c0..159e29a9cdf4f547de4d180bb946b1be4f0f655a 100644
--- a/PaddleCV/gan/data_reader.py
+++ b/PaddleCV/gan/data_reader.py
@@ -42,7 +42,7 @@ def CentorCrop(img, crop_w, crop_h):
 def RandomHorizonFlip(img):
     i = np.random.rand()
     if i > 0.5:
-        img = ImageOps.mirror(image)
+        img = ImageOps.mirror(img)
     return img
@@ -283,13 +283,21 @@ class celeba_reader_creator(reader_creator):
             if shuffle:
                 np.random.shuffle(self.images)
             for file, label in self.images:
-                img = Image.open(os.path.join(self.image_dir,
-                                              file)).convert('RGB')
-                label = np.array(label).astype("float32")
-                label = (label + 1) // 2
-                img = CentorCrop(img, args.crop_size, args.crop_size)
-                img = img.resize((args.load_size, args.load_size),
-                                 Image.BILINEAR)
+                if args.model_net == "StarGAN":
+                    img = Image.open(os.path.join(self.image_dir, file))
+                    label = np.array(label).astype("float32")
+                    img = RandomHorizonFlip(img)
+                    img = CentorCrop(img, args.crop_size, args.crop_size)
+                    img = img.resize((args.image_size, args.image_size),
+                                     Image.BILINEAR)
+                else:
+                    img = Image.open(os.path.join(self.image_dir,
+                                                  file)).convert('RGB')
+                    label = np.array(label).astype("float32")
+                    label = (label + 1) // 2
+                    img = CentorCrop(img, args.crop_size, args.crop_size)
+                    img = img.resize((args.load_size, args.load_size),
+                                     Image.BILINEAR)
                 img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
                 img = img.transpose([2, 0, 1])
@@ -310,12 +318,19 @@ class celeba_reader_creator(reader_creator):
             batch_out_2 = []
             batch_out_3 = []
             for file, label in self.images:
-                img = Image.open(os.path.join(self.image_dir, file)).convert(
-                    'RGB')
-                label = np.array(label).astype("float32")
-                img = CentorCrop(img, 170, 170)
-                img = img.resize((args.image_size, args.image_size),
-                                 Image.BILINEAR)
+                if args.model_net == 'StarGAN':
+                    img = Image.open(os.path.join(self.image_dir, file))
+                    label = np.array(label).astype("float32")
+                    img = CentorCrop(img, args.crop_size, args.crop_size)
+                    img = img.resize((args.image_size, args.image_size),
+                                     Image.BILINEAR)
+                else:
+                    img = Image.open(os.path.join(self.image_dir,
+                                                  file)).convert('RGB')
+                    label = np.array(label).astype("float32")
+                    img = CentorCrop(img, 170, 170)
+                    img = img.resize((args.image_size, args.image_size),
+                                     Image.BILINEAR)
                 img = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
                 img = img.transpose([2, 0, 1])
                 if return_name:
@@ -482,7 +497,7 @@ class data_reader(object):
                 self.cfg, shuffle=self.shuffle)
             return reader, reader_test, batch_num
-        else:
+        elif self.cfg.model_net == 'Pix2pix':
             dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
             train_list = os.path.join(dataset_dir, 'train.txt')
             if self.cfg.train_list is not None:
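For reference, the new StarGAN reader branch performs a random horizontal flip, a center crop to crop_size, a bilinear resize to image_size, scaling to [-1, 1], and an HWC-to-CHW transpose. A minimal standalone sketch of that pipeline (the helper name and default sizes are illustrative; the real reader takes every size from args):

    import numpy as np
    from PIL import Image, ImageOps

    def preprocess_stargan(path, crop_size=178, image_size=128):
        # random horizontal flip, as in RandomHorizonFlip above
        img = Image.open(path)
        if np.random.rand() > 0.5:
            img = ImageOps.mirror(img)
        # center crop to crop_size x crop_size, as in CentorCrop
        w, h = img.size
        left, top = (w - crop_size) // 2, (h - crop_size) // 2
        img = img.crop((left, top, left + crop_size, top + crop_size))
        # bilinear resize, then scale pixels to [-1, 1] and go HWC -> CHW
        img = img.resize((image_size, image_size), Image.BILINEAR)
        arr = (np.array(img).astype('float32') / 255.0 - 0.5) / 0.5
        return arr.transpose([2, 0, 1])  # shape (3, image_size, image_size)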
diff --git a/PaddleCV/gan/infer.py b/PaddleCV/gan/infer.py
index 3dda44c644da6d32a9277d340113d9ccc3c8ac23..81d6821d7667fba1c8655f458d15cfecda36c323 100644
--- a/PaddleCV/gan/infer.py
+++ b/PaddleCV/gan/infer.py
@@ -78,7 +78,10 @@ def infer(args):
         from network.Pix2pix_network import Pix2pix_model
         model = Pix2pix_model()
         fake = model.network_G(input, "generator", cfg=args)
-
+    elif args.model_net == 'StarGAN':
+        from network.StarGAN_network import StarGAN_model
+        model = StarGAN_model()
+        fake = model.network_G(input, label_trg_, name="g_main", cfg=args)
     elif args.model_net == 'STGAN':
         from network.STGAN_network import STGAN_model
         model = STGAN_model()
@@ -152,6 +155,37 @@ def infer(args):
             images_concat = np.concatenate(images_concat, 1)
             imsave(args.output + "/fake_img_" + name[0], (
                 (images_concat + 1) * 127.5).astype(np.uint8))
+    elif args.model_net == 'StarGAN':
+        test_reader = celeba_reader_creator(
+            image_dir=args.dataset_dir,
+            list_filename=args.test_list,
+            batch_size=args.batch_size,
+            drop_last=False,
+            args=args)
+        reader_test = test_reader.get_test_reader(
+            args, shuffle=False, return_name=True)
+        for data in zip(reader_test()):
+            real_img, label_org, name = data[0]
+            tensor_img = fluid.LoDTensor()
+            tensor_label_org = fluid.LoDTensor()
+            tensor_img.set(real_img, place)
+            tensor_label_org.set(label_org, place)
+            real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
+            images = [real_img_temp]
+            for i in range(args.c_dim):
+                label_trg = np.zeros([1, args.c_dim]).astype("float32")
+                label_trg[0][i] = 1
+                tensor_label_trg = fluid.LoDTensor()
+                tensor_label_trg.set(label_trg, place)
+                out = exe.run(
+                    feed={"input": tensor_img,
+                          "label_trg_": tensor_label_trg},
+                    fetch_list=[fake.name])
+                fake_temp = np.squeeze(out[0]).transpose([1, 2, 0])
+                images.append(fake_temp)
+            images_concat = np.concatenate(images, 1)
+            imsave(args.output + "/fake_img_" + name[0], (
+                (images_concat + 1) * 127.5).astype(np.uint8))
     elif args.model_net == 'Pix2pix' or args.model_net == 'cyclegan':
         for file in glob.glob(args.input):
diff --git a/PaddleCV/gan/network/StarGAN_network.py b/PaddleCV/gan/network/StarGAN_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..646648e32da1cb1583f02a85de04fcd5749e0b68
--- /dev/null
+++ b/PaddleCV/gan/network/StarGAN_network.py
@@ -0,0 +1,158 @@
+#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from .base_network import conv2d, deconv2d, norm_layer
+import paddle.fluid as fluid
+import numpy as np
+
+
+class StarGAN_model(object):
+    def __init__(self):
+        pass
+
+    def ResidualBlock(self, input, dim, name):
+        conv0 = conv2d(
+            input,
+            dim,
+            3,
+            1,
+            padding=1,
+            use_bias=False,
+            norm="instance_norm",
+            activation_fn='relu',
+            name=name + ".main0",
+            initial='kaiming')
+        conv1 = conv2d(
+            conv0,
+            dim,
+            3,
+            1,
+            padding=1,
+            use_bias=False,
+            norm="instance_norm",
+            activation_fn=None,
+            name=name + ".main3",
+            initial='kaiming')
+        return input + conv1
+
+    def network_G(self, input, label_trg, cfg, name="generator"):
+        repeat_num = 6
+        shape = input.shape
+        label_trg_e = fluid.layers.reshape(label_trg,
+                                           [-1, label_trg.shape[1], 1, 1])
+        label_trg_e = fluid.layers.expand(
+            x=label_trg_e, expand_times=[1, 1, shape[2], shape[3]])
+        input1 = fluid.layers.concat([input, label_trg_e], 1)
+        conv0 = conv2d(
+            input1,
+            cfg.g_conv_dim,
+            7,
+            1,
+            padding=3,
+            use_bias=False,
+            norm="instance_norm",
+            activation_fn='relu',
+            name=name + '0',
+            initial='kaiming')
+        conv_down = conv0
+        for i in range(2):
+            rate = 2**(i + 1)
+            conv_down = conv2d(
+                conv_down,
+                cfg.g_conv_dim * rate,
+                4,
+                2,
+                padding=1,
+                use_bias=False,
+                norm="instance_norm",
+                activation_fn='relu',
+                name=name + str(i * 3 + 3),
+                initial='kaiming')
+        res_block = conv_down
+        for i in range(repeat_num):
+            res_block = self.ResidualBlock(
+                res_block, cfg.g_conv_dim * (2**2), name=name + '.%d' % (i + 9))
+        deconv = res_block
+        for i in range(2):
+            rate = 2**(1 - i)
+            deconv = deconv2d(
+                deconv,
+                cfg.g_conv_dim * rate,
+                4,
+                2,
+                padding=1,
+                use_bias=False,
+                norm="instance_norm",
+                activation_fn='relu',
+                name=name + str(15 + i * 3),
+                initial='kaiming')
+        out = conv2d(
+            deconv,
+            3,
+            7,
+            1,
+            padding=3,
+            use_bias=False,
+            norm=None,
+            activation_fn='tanh',
+            name=name + '21',
+            initial='kaiming')
+        return out
+
+    def network_D(self, input, cfg, name="discriminator"):
+        conv0 = conv2d(
+            input,
+            cfg.d_conv_dim,
+            4,
+            2,
+            padding=1,
+            activation_fn='leaky_relu',
+            name=name + '0',
+            initial='kaiming')
+        repeat_num = 6
+        curr_dim = cfg.d_conv_dim
+        conv = conv0
+        for i in range(1, repeat_num):
+            curr_dim *= 2
+            conv = conv2d(
+                conv,
+                curr_dim,
+                4,
+                2,
+                padding=1,
+                activation_fn='leaky_relu',
+                name=name + str(i * 2),
+                initial='kaiming')
+        kernel_size = int(cfg.image_size / np.power(2, repeat_num))
+        out1 = conv2d(
+            conv,
+            1,
+            3,
+            1,
+            padding=1,
+            use_bias=False,
+            name="d_conv1",
+            initial='kaiming')
+        out2 = conv2d(
+            conv,
+            cfg.c_dim,
+            kernel_size,
+            use_bias=False,
+            name="d_conv2",
+            initial='kaiming')
+        return out1, out2
diff --git a/PaddleCV/gan/scripts/run_stargan.sh b/PaddleCV/gan/scripts/run_stargan.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a83d158955dbd1dfb728fc0c560f986d2b06789
--- /dev/null
+++ b/PaddleCV/gan/scripts/run_stargan.sh
@@ -0,0 +1,2 @@
+CUDA_VISIBLE_DEVICES=2 python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 200 > log_out 2>log_err
+#CUDA_VISIBLE_DEVICES=0 python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./test_list --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 2 --epoch 200 > log_out 2>log_err
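As a quick sanity check of the generator wiring above: two stride-2 downsampling convolutions are mirrored by two stride-2 deconvolutions, so the output resolution equals the input resolution. A hedged sketch (it assumes running from PaddleCV/gan so the package import resolves, and it stubs the parsed flags with a minimal object):

    import paddle.fluid as fluid
    from network.StarGAN_network import StarGAN_model

    class StubCfg(object):
        g_conv_dim = 64  # stand-in for the parsed --g_conv_dim flag

    image = fluid.layers.data(name='image', shape=[3, 128, 128], dtype='float32')
    label = fluid.layers.data(name='label', shape=[5], dtype='float32')
    fake = StarGAN_model().network_G(image, label, cfg=StubCfg(), name="g_main")
    print(fake.shape)  # expected: (-1, 3, 128, 128)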
diff --git a/PaddleCV/gan/train.py b/PaddleCV/gan/train.py
index 258e5aa979d82e5514230995d36a7cce2018715c..c73e1207253cbe51555ddb55f70fba7e47b8aed9 100644
--- a/PaddleCV/gan/train.py
+++ b/PaddleCV/gan/train.py
@@ -33,6 +33,8 @@ def train(cfg):
         )
     elif cfg.model_net == 'Pix2pix':
         train_reader, test_reader, batch_num = reader.make_data()
+    elif cfg.model_net == 'StarGAN':
+        train_reader, test_reader, batch_num = reader.make_data()
     else:
         if cfg.dataset == 'mnist':
             train_reader = reader.make_data()
@@ -56,6 +58,9 @@ def train(cfg):
     elif cfg.model_net == 'Pix2pix':
         from trainer.Pix2pix import Pix2pix
         model = Pix2pix(cfg, train_reader, test_reader, batch_num)
+    elif cfg.model_net == 'StarGAN':
+        from trainer.StarGAN import StarGAN
+        model = StarGAN(cfg, train_reader, test_reader, batch_num)
     elif cfg.model_net == 'AttGAN':
         from trainer.AttGAN import AttGAN
         model = AttGAN(cfg, train_reader, test_reader, batch_num)
diff --git a/PaddleCV/gan/trainer/StarGAN.py b/PaddleCV/gan/trainer/StarGAN.py
new file mode 100644
index 0000000000000000000000000000000000000000..d54c88c49f52408f032dc0bbeed1d14be94bdac3
--- /dev/null
+++ b/PaddleCV/gan/trainer/StarGAN.py
@@ -0,0 +1,402 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from network.StarGAN_network import StarGAN_model
+from util import utility
+import paddle.fluid as fluid
+import sys
+import time
+import copy
+import numpy as np
+import pickle as pkl
+
+
+class GTrainer():
+    def __init__(self, image_real, label_org, label_trg, cfg, step_per_epoch):
+        self.program = fluid.default_main_program().clone()
+        with fluid.program_guard(self.program):
+            model = StarGAN_model()
+            self.fake_img = model.network_G(
+                image_real, label_trg, cfg, name="g_main")
+            self.fake_img.persistable = True
+            self.rec_img = model.network_G(
+                self.fake_img, label_org, cfg, name="g_main")
+            self.rec_img.persistable = True
+            self.infer_program = self.program.clone(for_test=True)
+            self.g_loss_rec = fluid.layers.reduce_mean(
+                fluid.layers.abs(
+                    fluid.layers.elementwise_sub(
+                        x=image_real, y=self.rec_img)))
+            self.pred_fake, self.cls_fake = model.network_D(
+                self.fake_img, cfg, name="d_main")
+            #wgan
+            if cfg.gan_mode == "wgan":
+                self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
+            #lsgan
+            elif cfg.gan_mode == "lsgan":
+                ones = fluid.layers.fill_constant_batch_size_like(
+                    input=self.pred_fake,
+                    shape=self.pred_fake.shape,
+                    value=1,
+                    dtype='float32')
+                self.g_loss_fake = fluid.layers.mean(
+                    fluid.layers.square(
+                        fluid.layers.elementwise_sub(
+                            x=self.pred_fake, y=ones)))
+
+            cls_shape = self.cls_fake.shape
+            self.cls_fake = fluid.layers.reshape(
+                self.cls_fake,
+                [-1, cls_shape[1] * cls_shape[2] * cls_shape[3]])
+            self.g_loss_cls = fluid.layers.reduce_sum(
+                fluid.layers.sigmoid_cross_entropy_with_logits(
+                    self.cls_fake, label_trg)) / cfg.batch_size
+            self.g_loss = self.g_loss_fake + cfg.lambda_rec * self.g_loss_rec + cfg.lambda_cls * self.g_loss_cls
+            self.g_loss_fake.persistable = True
+            self.g_loss_rec.persistable = True
+            self.g_loss_cls.persistable = True
+            lr = cfg.g_lr
+            vars = []
+            for var in self.program.list_vars():
+                if fluid.io.is_parameter(var) and var.name.startswith("g_"):
+                    vars.append(var.name)
+            self.param = vars
+            total_iters = step_per_epoch * cfg.epoch
+            boundaries = [cfg.num_iters - cfg.num_iters_decay]
+            values = [lr]
+            for x in range(cfg.num_iters - cfg.num_iters_decay + 1,
+                           total_iters):
+                if x % cfg.lr_update_step == 0:
+                    boundaries.append(x)
+                    lr -= (lr / float(cfg.num_iters_decay))
+                    values.append(lr)
+            lr = values[-1]
+            lr -= (lr / float(cfg.num_iters_decay))
+            values.append(lr)
+            optimizer = fluid.optimizer.Adam(
+                learning_rate=fluid.layers.piecewise_decay(
+                    boundaries=boundaries, values=values),
+                beta1=0.5,
+                beta2=0.999,
+                name="net_G")
+            optimizer.minimize(self.g_loss, parameter_list=vars)
+            with open('program_gen.txt', 'w') as f:
+                print(self.program, file=f)
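The boundaries/values loop above keeps the learning rate flat until iteration num_iters - num_iters_decay, then steps it down every lr_update_step iterations; piecewise_decay needs one more value than boundaries, hence the extra append after the loop. A standalone rendering of the same arithmetic, assuming the trainer defaults (g_lr=0.0001, num_iters=200000, num_iters_decay=100000, lr_update_step=1000) and a run long enough that total_iters reaches 200000:

    g_lr, num_iters, num_iters_decay, lr_update_step = 0.0001, 200000, 100000, 1000
    total_iters = 200000

    lr = g_lr
    boundaries = [num_iters - num_iters_decay]  # decay starts at iteration 100000
    values = [lr]
    for x in range(num_iters - num_iters_decay + 1, total_iters):
        if x % lr_update_step == 0:
            boundaries.append(x)
            lr -= lr / float(num_iters_decay)
            values.append(lr)
    lr -= lr / float(num_iters_decay)
    values.append(lr)  # len(values) == len(boundaries) + 1, as piecewise_decay expects
    print(len(boundaries), len(values))  # 100, 101
    print(values[0], values[-1])         # 0.0001 down to the final decayed lr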
+
+
+class DTrainer():
+    def __init__(self, image_real, label_org, label_trg, cfg, step_per_epoch):
+        self.program = fluid.default_main_program().clone()
+        with fluid.program_guard(self.program):
+            model = StarGAN_model()
+            clone_image_real = None
+            for b in self.program.blocks:
+                if b.has_var('image_real'):
+                    clone_image_real = b.var('image_real')
+                    break
+            self.fake_img = model.network_G(
+                image_real, label_trg, cfg, name="g_main")
+            self.pred_real, self.cls_real = model.network_D(
+                image_real, cfg, name="d_main")
+            self.pred_fake, _ = model.network_D(
+                self.fake_img, cfg, name="d_main")
+            cls_shape = self.cls_real.shape
+            self.cls_real = fluid.layers.reshape(
+                self.cls_real,
+                [-1, cls_shape[1] * cls_shape[2] * cls_shape[3]])
+            self.d_loss_cls = fluid.layers.reduce_sum(
+                fluid.layers.sigmoid_cross_entropy_with_logits(
+                    self.cls_real, label_org)) / cfg.batch_size
+            #wgan
+            if cfg.gan_mode == "wgan":
+                self.d_loss_fake = fluid.layers.mean(self.pred_fake)
+                self.d_loss_real = -1 * fluid.layers.mean(self.pred_real)
+                self.d_loss_gp = self.gradient_penalty(
+                    getattr(model, "network_D"),
+                    clone_image_real,
+                    self.fake_img,
+                    cfg=cfg,
+                    name="d_main")
+                self.d_loss = self.d_loss_real + self.d_loss_fake + cfg.lambda_cls * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
+                self.d_loss_gp.persistable = True
+            #lsgan
+            elif cfg.gan_mode == "lsgan":
+                ones = fluid.layers.fill_constant_batch_size_like(
+                    input=self.pred_real,
+                    shape=self.pred_real.shape,
+                    value=1,
+                    dtype='float32')
+                self.d_loss_real = fluid.layers.mean(
+                    fluid.layers.square(
+                        fluid.layers.elementwise_sub(
+                            x=self.pred_real, y=ones)))
+                self.d_loss_fake = fluid.layers.mean(
+                    fluid.layers.square(x=self.pred_fake))
+                self.d_loss = self.d_loss_real + self.d_loss_fake + cfg.lambda_cls * self.d_loss_cls
+
+            self.d_loss_real.persistable = True
+            self.d_loss_fake.persistable = True
+            self.d_loss_cls.persistable = True
+            vars = []
+            for var in self.program.list_vars():
+                if fluid.io.is_parameter(var) and var.name.startswith("d_"):
+                    vars.append(var.name)
+
+            self.param = vars
+            total_iters = step_per_epoch * cfg.epoch
+            boundaries = [cfg.num_iters - cfg.num_iters_decay]
+            values = [cfg.d_lr]
+            lr = cfg.d_lr
+            for x in range(cfg.num_iters - cfg.num_iters_decay + 1,
+                           total_iters):
+                if x % cfg.lr_update_step == 0:
+                    boundaries.append(x)
+                    lr -= (lr / float(cfg.num_iters_decay))
+                    values.append(lr)
+            lr = values[-1]
+            lr -= (lr / float(cfg.num_iters_decay))
+            values.append(lr)
+            optimizer = fluid.optimizer.Adam(
+                learning_rate=fluid.layers.piecewise_decay(
+                    boundaries=boundaries, values=values),
+                beta1=0.5,
+                beta2=0.999,
+                name="net_D")
+
+            optimizer.minimize(self.d_loss, parameter_list=vars)
+            with open('program_dis.txt', 'w') as f:
+                print(self.program, file=f)
+
+    def gradient_penalty(self, f, real, fake, cfg=None, name=None):
+        def _interpolate(a, b):
+            shape = [a.shape[0]]
+            alpha = fluid.layers.uniform_random_batch_size_like(
+                input=a, shape=shape, min=0.0, max=1.0)
+            a.stop_gradient = True
+            b.stop_gradient = True
+            inner1 = fluid.layers.elementwise_mul(a, alpha, axis=0)
+            inner2 = fluid.layers.elementwise_mul(b, (1.0 - alpha), axis=0)
+            inner1.stop_gradient = True
+            inner2.stop_gradient = True
+            inner = inner1 + inner2
+            return inner
+
+        x = _interpolate(real, fake)
+        pred, _ = f(x, cfg, name=name)
+        if isinstance(pred, tuple):
+            pred = pred[0]
+        vars = []
+        for var in fluid.default_main_program().list_vars():
+            if fluid.io.is_parameter(var) and var.name.startswith('d_'):
+                vars.append(var.name)
+        grad = fluid.gradients(pred, x, no_grad_set=vars)[0]
+        grad_shape = grad.shape
+        grad = fluid.layers.reshape(
+            grad, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]])
+        norm = fluid.layers.sqrt(
+            fluid.layers.reduce_sum(
+                fluid.layers.square(grad), dim=1))
+        gp = fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0))
+        return gp
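gradient_penalty above computes the standard WGAN-GP term, E[(||grad_x D(x_hat)||_2 - 1)^2], evaluated at random per-sample interpolates x_hat = alpha * real + (1 - alpha) * fake. The reduction arithmetic can be checked in isolation with plain numpy (made-up gradient values; shapes assume a 16-image batch at 128x128):

    import numpy as np

    batch, c, h, w = 16, 3, 128, 128
    grad = np.random.randn(batch, c, h, w).astype('float32')  # stands in for dD/dx_hat
    grad = grad.reshape(batch, -1)                # flatten per sample
    norm = np.sqrt((grad ** 2).sum(axis=1))       # per-sample L2 norm
    gp = ((norm - 1.0) ** 2).mean()               # mean squared distance from 1
    print(gp)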
+
+
+class StarGAN(object):
+    def add_special_args(self, parser):
+        parser.add_argument(
+            '--image_size', type=int, default=256, help="image size")
+        parser.add_argument(
+            '--g_lr', type=float, default=0.0001, help="learning rate of g")
+        parser.add_argument(
+            '--d_lr', type=float, default=0.0001, help="learning rate of d")
+        parser.add_argument(
+            '--c_dim',
+            type=int,
+            default=5,
+            help="the number of selected attributes")
+        parser.add_argument(
+            '--g_conv_dim',
+            type=int,
+            default=64,
+            help="base conv dims in generator")
+        parser.add_argument(
+            '--d_conv_dim',
+            type=int,
+            default=64,
+            help="base conv dims in discriminator")
+        parser.add_argument(
+            '--g_repeat_num',
+            type=int,
+            default=6,
+            help="number of layers in generator")
+        parser.add_argument(
+            '--d_repeat_num',
+            type=int,
+            default=6,
+            help="number of layers in discriminator")
+        parser.add_argument(
+            '--num_iters', type=int, default=200000, help="num iters")
+        parser.add_argument(
+            '--num_iters_decay',
+            type=int,
+            default=100000,
+            help="num iters decay")
+        parser.add_argument(
+            '--lr_update_step',
+            type=int,
+            default=1000,
+            help="interval (in iterations) between lr updates")
+        parser.add_argument(
+            '--lambda_cls',
+            type=float,
+            default=1.0,
+            help="the coefficient of the classification loss")
+        parser.add_argument(
+            '--lambda_rec',
+            type=float,
+            default=10.0,
+            help="the coefficient of the reconstruction loss")
+        parser.add_argument(
+            '--lambda_gp',
+            type=float,
+            default=10.0,
+            help="the coefficient of the gradient penalty")
+        parser.add_argument(
+            '--n_critic',
+            type=int,
+            default=5,
+            help="number of discriminator updates per generator update")
+        parser.add_argument(
+            '--selected_attrs',
+            type=str,
+            default="Black_Hair,Blond_Hair,Brown_Hair,Male,Young",
+            help="the attributes we selected to change")
+        parser.add_argument(
+            '--n_samples', type=int, default=1, help="batch size when testing")
+
+        return parser
+
+    def __init__(self,
+                 cfg=None,
+                 train_reader=None,
+                 test_reader=None,
+                 batch_num=1):
+        self.cfg = cfg
+        self.train_reader = train_reader
+        self.test_reader = test_reader
+        self.batch_num = batch_num
+
+    def build_model(self):
+        data_shape = [-1, 3, self.cfg.image_size, self.cfg.image_size]
+
+        image_real = fluid.layers.data(
+            name='image_real', shape=data_shape, dtype='float32')
+        label_org = fluid.layers.data(
+            name='label_org', shape=[self.cfg.c_dim], dtype='float32')
+        label_trg = fluid.layers.data(
+            name='label_trg', shape=[self.cfg.c_dim], dtype='float32')
+        gen_trainer = GTrainer(image_real, label_org, label_trg, self.cfg,
+                               self.batch_num)
+        dis_trainer = DTrainer(image_real, label_org, label_trg, self.cfg,
+                               self.batch_num)
+
+        # prepare environment
+        place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+
+        with open('program.txt', "w") as f:
+            print(gen_trainer.program, file=f)
+
+        if self.cfg.init_model:
+            utility.init_checkpoints(self.cfg, exe, gen_trainer, "net_G")
+            utility.init_checkpoints(self.cfg, exe, dis_trainer, "net_D")
+
+        ### memory optim
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.enable_inplace = False
+        build_strategy.memory_optimize = False
+
+        gen_trainer_program = fluid.CompiledProgram(
+            gen_trainer.program).with_data_parallel(
+                loss_name=gen_trainer.g_loss.name,
+                build_strategy=build_strategy)
+        dis_trainer_program = fluid.CompiledProgram(
+            dis_trainer.program).with_data_parallel(
+                loss_name=dis_trainer.d_loss.name,
+                build_strategy=build_strategy)
+
+        #losses = [[], []]
+        t_time = 0
+
+        test_program = gen_trainer.infer_program
+        utility.save_test_image(0, self.cfg, exe, place, test_program,
+                                gen_trainer, self.test_reader)
+        for epoch_id in range(self.cfg.epoch):
+            batch_id = 0
+            reader = self.train_reader()
+            for i in range(self.batch_num):
+                image, label_org = next(reader)
+                label_trg = copy.deepcopy(label_org)
+                np.random.shuffle(label_trg)
+
+                tensor_img = fluid.LoDTensor()
+                tensor_label_org = fluid.LoDTensor()
+                tensor_label_trg = fluid.LoDTensor()
+                tensor_img.set(image, place)
+                tensor_label_org.set(label_org, place)
+                tensor_label_trg.set(label_trg, place)
+                s_time = time.time()
+                # optimize the discriminator network
+                d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
+                    dis_trainer_program,
+                    fetch_list=[
+                        dis_trainer.d_loss_real, dis_trainer.d_loss_fake,
+                        dis_trainer.d_loss, dis_trainer.d_loss_cls,
+                        dis_trainer.d_loss_gp
+                    ],
+                    feed={
+                        "image_real": tensor_img,
+                        "label_org": tensor_label_org,
+                        "label_trg": tensor_label_trg
+                    })
+                # optimize the generator network
+                if (batch_id + 1) % self.cfg.n_critic == 0:
+                    g_loss_fake, g_loss_rec, g_loss_cls, fake_img, rec_img = exe.run(
+                        gen_trainer_program,
+                        fetch_list=[
+                            gen_trainer.g_loss_fake, gen_trainer.g_loss_rec,
+                            gen_trainer.g_loss_cls, gen_trainer.fake_img,
+                            gen_trainer.rec_img
+                        ],
+                        feed={
+                            "image_real": tensor_img,
+                            "label_org": tensor_label_org,
+                            "label_trg": tensor_label_trg
+                        })
+                    print("epoch{}: batch{}: \n\
+                         g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
+                          .format(epoch_id, batch_id, g_loss_fake[0],
+                                  g_loss_rec[0], g_loss_cls[0]))
+
+                batch_time = time.time() - s_time
+                t_time += batch_time
+                if batch_id % self.cfg.print_freq == 0:
+                    print("epoch{}: batch{}: \n\
+                         d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
+                         Batch_time_cost: {:.2f}".format(
+                        epoch_id, batch_id, d_loss_real[0], d_loss_fake[
+                            0], d_loss_cls[0], d_loss_gp[0], batch_time))
+
+                sys.stdout.flush()
+                batch_id += 1
+
+            if self.cfg.run_test:
+                test_program = gen_trainer.infer_program
+                utility.save_test_image(epoch_id, self.cfg, exe, place,
+                                        test_program, gen_trainer,
+                                        self.test_reader)
+
+            if self.cfg.save_checkpoints:
+                utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
+                                    "net_G")
+                utility.checkpoints(epoch_id, self.cfg, exe, dis_trainer,
+                                    "net_D")
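Target domains in the training loop above are drawn by permuting the batch's own attribute rows, so each image is pushed toward some other real sample's attribute combination. A small numpy illustration (made-up 4-sample batch with c_dim=5):

    import copy
    import numpy as np

    label_org = np.array([[1, 0, 0, 1, 1],
                          [0, 1, 0, 0, 1],
                          [0, 0, 1, 1, 0],
                          [1, 0, 0, 0, 0]], dtype='float32')
    label_trg = copy.deepcopy(label_org)
    np.random.shuffle(label_trg)  # shuffles rows in place; columns stay aligned
    print(label_trg)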
diff --git a/PaddleCV/gan/util/utility.py b/PaddleCV/gan/util/utility.py
index 933c63c034dc7049bb4dadadcdff44b3da8141d2..b9242bf9ccc58170f6bcd8e6a31af8e274116e0b 100644
--- a/PaddleCV/gan/util/utility.py
+++ b/PaddleCV/gan/util/utility.py
@@ -77,7 +77,7 @@ def save_test_image(epoch,
     out_path = cfg.output + '/test'
     if not os.path.exists(out_path):
         os.makedirs(out_path)
-    if B_test_reader is None:
+    if cfg.model_net == "Pix2pix":
         for data in zip(A_test_reader()):
             data_A, data_B, name = data[0]
             name = name[0]
@@ -100,85 +100,112 @@ def save_test_image(epoch,
             (input_A_temp + 1) * 127.5).astype(np.uint8))
         imsave(out_path + "/inputB_" + str(epoch) + "_" + name, (
             (input_B_temp + 1) * 127.5).astype(np.uint8))
-    else:
-        if cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
-            for data in zip(A_test_reader()):
-                real_img, label_org, name = data[0]
-                label_trg = copy.deepcopy(label_org)
-                tensor_img = fluid.LoDTensor()
-                tensor_label_org = fluid.LoDTensor()
+    elif cfg.model_net == "StarGAN":
+        for data in zip(A_test_reader()):
+            real_img, label_org, name = data[0]
+            tensor_img = fluid.LoDTensor()
+            tensor_label_org = fluid.LoDTensor()
+            tensor_img.set(real_img, place)
+            tensor_label_org.set(label_org, place)
+            real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
+            images = [real_img_temp]
+            for i in range(cfg.c_dim):
+                label_trg = np.zeros([1, cfg.c_dim]).astype("float32")
+                label_trg[0][i] = 1
                 tensor_label_trg = fluid.LoDTensor()
-                tensor_label_org_ = fluid.LoDTensor()
-                tensor_label_trg_ = fluid.LoDTensor()
-                tensor_img.set(real_img, place)
-                tensor_label_org.set(label_org, place)
-                real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
-                images = [real_img_temp]
-                for i in range(cfg.c_dim):
-                    label_trg_tmp = copy.deepcopy(label_trg)
-
-                    for j in range(len(label_org)):
-                        label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
-
-                    label_trg_ = map(lambda x: ((x * 2) - 1) * 0.5,
-                                     label_trg_tmp)
-
-                    for j in range(len(label_org)):
-                        label_trg_[j][i] = label_trg_[j][i] * 2.0
-                    tensor_label_org_.set(label_org, place)
-                    tensor_label_trg.set(label_trg, place)
-                    tensor_label_trg_.set(label_trg_, place)
-                    out = exe.run(test_program,
-                                  feed={
-                                      "image_real": tensor_img,
-                                      "label_org": tensor_label_org,
-                                      "label_org_": tensor_label_org_,
-                                      "label_trg": tensor_label_trg,
-                                      "label_trg_": tensor_label_trg_
-                                  },
-                                  fetch_list=[g_trainer.fake_img])
-                fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
-                images.append(fake_temp)
-                images_concat = np.concatenate(images, 1)
-                images_concat = np.concatenate(images_concat, 1)
-                imsave(out_path + "/fake_img" + str(epoch) + '_' + name[0], (
-                    (images_concat + 1) * 127.5).astype(np.uint8))
-
-        else:
-            for data_A, data_B in zip(A_test_reader(), B_test_reader()):
-                A_name = data_A[0][1]
-                B_name = data_B[0][1]
-                tensor_A = fluid.LoDTensor()
-                tensor_B = fluid.LoDTensor()
-                tensor_A.set(data_A[0][0], place)
-                tensor_B.set(data_B[0][0], place)
-                fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
+                tensor_label_trg.set(label_trg, place)
+                fake_temp, rec_temp = exe.run(
                     test_program,
-                    fetch_list=[
-                        g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A,
-                        g_trainer.cyc_B
-                    ],
-                    feed={"input_A": tensor_A,
-                          "input_B": tensor_B})
-                fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0])
-                fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
-                cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0])
-                cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0])
-                input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
-                input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
-
-                imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
-                    (fake_B_temp + 1) * 127.5).astype(np.uint8))
-                imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
-                    (fake_A_temp + 1) * 127.5).astype(np.uint8))
-                imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
-                    (cyc_A_temp + 1) * 127.5).astype(np.uint8))
-                imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
-                    (cyc_B_temp + 1) * 127.5).astype(np.uint8))
-                imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
-                    (input_A_temp + 1) * 127.5).astype(np.uint8))
-                imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
-                    (input_B_temp + 1) * 127.5).astype(np.uint8))
+                feed={
+                    "image_real": tensor_img,
+                    "label_org": tensor_label_org,
+                    "label_trg": tensor_label_trg
+                },
+                fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
+                fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
+                rec_temp = np.squeeze(rec_temp[0]).transpose([1, 2, 0])
+                images.append(fake_temp)
+                images.append(rec_temp)
+            images_concat = np.concatenate(images, 1)
+            imsave(out_path + "/fake_img" + str(epoch) + "_" + name[0], (
+                (images_concat + 1) * 127.5).astype(np.uint8))
+    elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
+        for data in zip(A_test_reader()):
+            real_img, label_org, name = data[0]
+            label_trg = copy.deepcopy(label_org)
+            tensor_img = fluid.LoDTensor()
+            tensor_label_org = fluid.LoDTensor()
+            tensor_label_trg = fluid.LoDTensor()
+            tensor_label_org_ = fluid.LoDTensor()
+            tensor_label_trg_ = fluid.LoDTensor()
+            tensor_img.set(real_img, place)
+            tensor_label_org.set(label_org, place)
+            real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
+            images = [real_img_temp]
+            for i in range(cfg.c_dim):
+                label_trg_tmp = copy.deepcopy(label_trg)
+
+                for j in range(len(label_org)):
+                    label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
+
+                label_trg_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
+
+                for j in range(len(label_org)):
+                    label_trg_[j][i] = label_trg_[j][i] * 2.0
+                tensor_label_org_.set(label_org, place)
+                tensor_label_trg.set(label_trg, place)
+                tensor_label_trg_.set(label_trg_, place)
+                out = exe.run(test_program,
+                              feed={
+                                  "image_real": tensor_img,
+                                  "label_org": tensor_label_org,
+                                  "label_org_": tensor_label_org_,
+                                  "label_trg": tensor_label_trg,
+                                  "label_trg_": tensor_label_trg_
+                              },
+                              fetch_list=[g_trainer.fake_img])
+                fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
+                images.append(fake_temp)
+            images_concat = np.concatenate(images, 1)
+            images_concat = np.concatenate(images_concat, 1)
+            imsave(out_path + "/fake_img" + str(epoch) + '_' + name[0], (
+                (images_concat + 1) * 127.5).astype(np.uint8))
+
+    else:
+        for data_A, data_B in zip(A_test_reader(), B_test_reader()):
+            A_name = data_A[0][1]
+            B_name = data_B[0][1]
+            tensor_A = fluid.LoDTensor()
+            tensor_B = fluid.LoDTensor()
+            tensor_A.set(data_A[0][0], place)
+            tensor_B.set(data_B[0][0], place)
+            fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
+                test_program,
+                fetch_list=[
+                    g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A,
+                    g_trainer.cyc_B
+                ],
+                feed={"input_A": tensor_A,
+                      "input_B": tensor_B})
+            fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0])
+            fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
+            cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0])
+            cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0])
+            input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
+            input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
+
+            imsave(out_path + "/fakeB_" + str(epoch) + "_" + A_name, (
+                (fake_B_temp + 1) * 127.5).astype(np.uint8))
+            imsave(out_path + "/fakeA_" + str(epoch) + "_" + B_name, (
+                (fake_A_temp + 1) * 127.5).astype(np.uint8))
+            imsave(out_path + "/cycA_" + str(epoch) + "_" + A_name, (
+                (cyc_A_temp + 1) * 127.5).astype(np.uint8))
+            imsave(out_path + "/cycB_" + str(epoch) + "_" + B_name, (
+                (cyc_B_temp + 1) * 127.5).astype(np.uint8))
+            imsave(out_path + "/inputA_" + str(epoch) + "_" + A_name, (
+                (input_A_temp + 1) * 127.5).astype(np.uint8))
+            imsave(out_path + "/inputB_" + str(epoch) + "_" + B_name, (
+                (input_B_temp + 1) * 127.5).astype(np.uint8))
 
 
 class ImagePool(object):
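All of the imsave calls above map network output in [-1, 1] back to uint8 pixels with the same expression; a minimal check of that mapping:

    import numpy as np

    img = np.array([-1.0, 0.0, 1.0], dtype='float32')
    print(((img + 1) * 127.5).astype(np.uint8))  # [  0 127 255]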