From f489ca5de32f56362314a3970351cbbaefe0024c Mon Sep 17 00:00:00 2001 From: lvmengsi Date: Thu, 27 Jun 2019 21:15:29 +0800 Subject: [PATCH] Fix batch norm in infer (#2570) * fix infer and conflict attribute --- PaddleCV/gan/infer.py | 18 +++++++- PaddleCV/gan/network/AttGAN_network.py | 31 ++++++++++---- PaddleCV/gan/network/STGAN_network.py | 58 ++++++++++++++++++-------- PaddleCV/gan/network/base_network.py | 21 ++++++---- PaddleCV/gan/scripts/infer_attgan.sh | 1 + PaddleCV/gan/scripts/infer_stargan.sh | 1 + PaddleCV/gan/scripts/infer_stgan.sh | 1 + PaddleCV/gan/scripts/run_stargan.sh | 2 +- PaddleCV/gan/scripts/run_stgan.sh | 2 +- PaddleCV/gan/util/utility.py | 31 ++++++++++++++ 10 files changed, 128 insertions(+), 38 deletions(-) create mode 100644 PaddleCV/gan/scripts/infer_attgan.sh create mode 100644 PaddleCV/gan/scripts/infer_stargan.sh create mode 100644 PaddleCV/gan/scripts/infer_stgan.sh diff --git a/PaddleCV/gan/infer.py b/PaddleCV/gan/infer.py index bee5fb55..8670a759 100644 --- a/PaddleCV/gan/infer.py +++ b/PaddleCV/gan/infer.py @@ -27,6 +27,7 @@ import imageio import glob from util.config import add_arguments, print_arguments from data_reader import celeba_reader_creator +from util.utility import check_attribute_conflict import copy parser = argparse.ArgumentParser(description=__doc__) @@ -86,12 +87,22 @@ def infer(args): from network.STGAN_network import STGAN_model model = STGAN_model() fake, _ = model.network_G( - input, label_org_, label_trg_, cfg=args, name='net_G') + input, + label_org_, + label_trg_, + cfg=args, + name='generator', + is_test=True) elif args.model_net == 'AttGAN': from network.AttGAN_network import AttGAN_model model = AttGAN_model() fake, _ = model.network_G( - input, label_org_, label_trg_, cfg=args, name='net_G') + input, + label_org_, + label_trg_, + cfg=args, + name='generator', + is_test=True) else: raise NotImplementedError("model_net {} is not support".format( args.model_net)) @@ -122,6 +133,7 @@ def infer(args): args, shuffle=False, return_name=True) for data in zip(reader_test()): real_img, label_org, name = data[0] + attr_names = args.selected_attrs.split(',') print("read {}".format(name)) label_trg = copy.deepcopy(label_org) tensor_img = fluid.LoDTensor() @@ -137,6 +149,8 @@ def infer(args): label_trg_tmp = copy.deepcopy(label_trg) for j in range(len(label_org)): label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] + label_trg_tmp = check_attribute_conflict( + label_trg_tmp, attr_names[i], attr_names) label_trg_ = list( map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) for j in range(len(label_org)): diff --git a/PaddleCV/gan/network/AttGAN_network.py b/PaddleCV/gan/network/AttGAN_network.py index f2ab280f..7bd91218 100755 --- a/PaddleCV/gan/network/AttGAN_network.py +++ b/PaddleCV/gan/network/AttGAN_network.py @@ -27,17 +27,26 @@ class AttGAN_model(object): def __init__(self): pass - def network_G(self, input, label_org, label_trg, cfg, name="generator"): + def network_G(self, + input, + label_org, + label_trg, + cfg, + name="generator", + is_test=False): _a = label_org _b = label_trg z = self.Genc( input, name=name + '_Genc', dim=cfg.g_base_dims, - n_layers=cfg.n_layers) - fake_image = self.Gdec(z, _b, name=name + '_Gdec', dim=cfg.g_base_dims) + n_layers=cfg.n_layers, + is_test=is_test) + fake_image = self.Gdec( + z, _b, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test) - rec_image = self.Gdec(z, _a, name=name + '_Gdec', dim=cfg.g_base_dims) + rec_image = self.Gdec( + z, _a, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test) return fake_image, rec_image def network_D(self, input, cfg, name="discriminator"): @@ -54,7 +63,7 @@ class AttGAN_model(object): z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0) return fluid.layers.concat([z, ones * a], axis=1) - def Genc(self, input, dim=64, n_layers=5, name='G_enc_'): + def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False): z = input zs = [] for i in range(n_layers): @@ -71,7 +80,8 @@ class AttGAN_model(object): name=name + str(i), use_bias=False, relufactor=0.01, - initial='kaiming') + initial='kaiming', + is_test=is_test) zs.append(z) return zs @@ -83,7 +93,8 @@ class AttGAN_model(object): n_layers=5, shortcut_layers=1, inject_layers=1, - name='G_dec_'): + name='G_dec_', + is_test=False): shortcut_layers = min(shortcut_layers, n_layers - 1) inject_layers = min(inject_layers, n_layers - 1) @@ -101,7 +112,8 @@ class AttGAN_model(object): norm='batch_norm', activation_fn='relu', use_bias=False, - initial='kaiming') + initial='kaiming', + is_test=is_test) if shortcut_layers > i: z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1) if inject_layers > i: @@ -116,7 +128,8 @@ class AttGAN_model(object): name=name + str(i), activation_fn='tanh', use_bias=True, - initial='kaiming') + initial='kaiming', + is_test=is_test) return x def D(self, diff --git a/PaddleCV/gan/network/STGAN_network.py b/PaddleCV/gan/network/STGAN_network.py index d7f4132b..4a1111cc 100755 --- a/PaddleCV/gan/network/STGAN_network.py +++ b/PaddleCV/gan/network/STGAN_network.py @@ -27,37 +27,48 @@ class STGAN_model(object): def __init__(self): pass - def network_G(self, input, label_org, label_trg, cfg, name="generator"): + def network_G(self, + input, + label_org, + label_trg, + cfg, + name="generator", + is_test=False): _a = label_org _b = label_trg z = self.Genc( input, name=name + '_Genc', n_layers=cfg.n_layers, - dim=cfg.g_base_dims) + dim=cfg.g_base_dims, + is_test=is_test) zb = self.GRU(z, fluid.layers.elementwise_sub(_b, _a), name=name + '_GRU', dim=cfg.g_base_dims, - n_layers=cfg.gru_n_layers) if cfg.use_gru else z + n_layers=cfg.gru_n_layers, + is_test=is_test) if cfg.use_gru else z fake_image = self.Gdec( zb, fluid.layers.elementwise_sub(_b, _a), name=name + '_Gdec', dim=cfg.g_base_dims, - n_layers=cfg.n_layers) + n_layers=cfg.n_layers, + is_test=is_test) za = self.GRU(z, fluid.layers.elementwise_sub(_a, _a), name=name + '_GRU', dim=cfg.g_base_dims, - n_layers=cfg.gru_n_layers) if cfg.use_gru else z + n_layers=cfg.gru_n_layers, + is_test=is_test) if cfg.use_gru else z rec_image = self.Gdec( za, fluid.layers.elementwise_sub(_a, _a), name=name + '_Gdec', dim=cfg.g_base_dims, - n_layers=cfg.n_layers) + n_layers=cfg.n_layers, + is_test=is_test) return fake_image, rec_image def network_D(self, input, cfg, name="discriminator"): @@ -74,7 +85,7 @@ class STGAN_model(object): z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0) return fluid.layers.concat([z, ones * a], axis=1) - def Genc(self, input, dim=64, n_layers=5, name='G_enc_'): + def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False): z = input zs = [] for i in range(n_layers): @@ -90,7 +101,8 @@ class STGAN_model(object): name=name + str(i), use_bias=False, relufactor=0.01, - initial='kaiming') + initial='kaiming', + is_test=is_test) zs.append(z) return zs @@ -104,7 +116,8 @@ class STGAN_model(object): kernel_size=3, norm=None, pass_state='lstate', - name='G_gru_'): + name='G_gru_', + is_test=False): zs_ = [zs[-1]] state = self.concat(zs[-1], a) @@ -117,7 +130,8 @@ class STGAN_model(object): kernel_size=kernel_size, norm=norm, pass_state=pass_state, - name=name + str(i)) + name=name + str(i), + is_test=is_test) zs_.insert(0, output[0] + zs[n_layers - 1 - i]) if inject_layers > i: state = self.concat(output[1], a) @@ -132,7 +146,8 @@ class STGAN_model(object): n_layers=5, shortcut_layers=4, inject_layers=4, - name='G_dec_'): + name='G_dec_', + is_test=False): shortcut_layers = min(shortcut_layers, n_layers - 1) inject_layers = min(inject_layers, n_layers - 1) @@ -150,7 +165,8 @@ class STGAN_model(object): norm='batch_norm', activation_fn='relu', use_bias=False, - initial='kaiming') + initial='kaiming', + is_test=is_test) if shortcut_layers > i: z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1) if inject_layers > i: @@ -165,7 +181,8 @@ class STGAN_model(object): name=name + str(i), activation_fn='tanh', use_bias=True, - initial='kaiming') + initial='kaiming', + is_test=is_test) return x def D(self, @@ -220,7 +237,8 @@ class STGAN_model(object): kernel_size=3, norm=None, pass_state='lstate', - name='gru'): + name='gru', + is_test=False): state_ = deconv2d( state, out_channel, @@ -229,7 +247,8 @@ class STGAN_model(object): padding_type='SAME', name=name + '_deconv2d', use_bias=True, - initial='kaiming' + initial='kaiming', + is_test=is_test, ) # upsample and make `channel` identical to `out_channel` reset_gate = conv2d( fluid.layers.concat( @@ -241,7 +260,8 @@ class STGAN_model(object): padding_type='SAME', use_bias=True, name=name + '_reset_gate', - initial='kaiming') + initial='kaiming', + is_test=is_test) update_gate = conv2d( fluid.layers.concat( [in_data, state_], axis=1), @@ -252,7 +272,8 @@ class STGAN_model(object): padding_type='SAME', use_bias=True, name=name + '_update_gate', - initial='kaiming') + initial='kaiming', + is_test=is_test) left_state = reset_gate * state_ new_info = conv2d( fluid.layers.concat( @@ -264,7 +285,8 @@ class STGAN_model(object): name=name + '_info', padding_type='SAME', use_bias=True, - initial='kaiming') + initial='kaiming', + is_test=is_test) output = (1 - update_gate) * state_ + update_gate * new_info if pass_state == 'output': return output, output diff --git a/PaddleCV/gan/network/base_network.py b/PaddleCV/gan/network/base_network.py index 691486ce..ab989df5 100644 --- a/PaddleCV/gan/network/base_network.py +++ b/PaddleCV/gan/network/base_network.py @@ -34,7 +34,7 @@ def cal_padding(img_size, stride, filter_size, dilation=1): return out_size // 2, out_size - out_size // 2 -def norm_layer(input, norm_type='batch_norm', name=None): +def norm_layer(input, norm_type='batch_norm', name=None, is_test=False): if norm_type == 'batch_norm': param_attr = fluid.ParamAttr( name=name + '_w', initializer=fluid.initializer.Constant(1.0)) @@ -44,6 +44,7 @@ def norm_layer(input, norm_type='batch_norm', name=None): input, param_attr=param_attr, bias_attr=bias_attr, + is_test=is_test, moving_mean_name=name + '_mean', moving_variance_name=name + '_var') @@ -133,7 +134,8 @@ def conv2d(input, relufactor=0.0, use_bias=False, padding_type=None, - initial="normal"): + initial="normal", + is_test=False): if padding != 0 and padding_type != None: warnings.warn( @@ -181,7 +183,8 @@ def conv2d(input, param_attr=param_attr, bias_attr=bias_attr) if norm is not None: - conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm") + conv = norm_layer( + input=conv, norm_type=norm, name=name + "_norm", is_test=is_test) if activation_fn == 'relu': conv = fluid.layers.relu(conv, name=name + '_relu') elif activation_fn == 'leaky_relu': @@ -214,7 +217,8 @@ def deconv2d(input, use_bias=False, padding_type=None, output_size=None, - initial="normal"): + initial="normal", + is_test=False): if padding != 0 and padding_type != None: warnings.warn( @@ -268,7 +272,8 @@ def deconv2d(input, conv, paddings=outpadding, mode='constant', pad_value=0.0) if norm is not None: - conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm") + conv = norm_layer( + input=conv, norm_type=norm, name=name + "_norm", is_test=is_test) if activation_fn == 'relu': conv = fluid.layers.relu(conv, name=name + '_relu') elif activation_fn == 'leaky_relu': @@ -297,7 +302,8 @@ def linear(input, activation_fn=None, relufactor=0.2, name="linear", - initial="normal"): + initial="normal", + is_test=False): param_attr, bias_attr = initial_type( name=name, @@ -316,7 +322,8 @@ def linear(input, name=name) if norm is not None: - linear = norm_layer(input=linear, norm_type=norm, name=name + '_norm') + linear = norm_layer( + input=linear, norm_type=norm, name=name + '_norm', is_test=is_test) if activation_fn == 'relu': linear = fluid.layers.relu(linear, name=name + '_relu') elif activation_fn == 'leaky_relu': diff --git a/PaddleCV/gan/scripts/infer_attgan.sh b/PaddleCV/gan/scripts/infer_attgan.sh new file mode 100644 index 00000000..13e7de17 --- /dev/null +++ b/PaddleCV/gan/scripts/infer_attgan.sh @@ -0,0 +1 @@ +python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128 diff --git a/PaddleCV/gan/scripts/infer_stargan.sh b/PaddleCV/gan/scripts/infer_stargan.sh new file mode 100644 index 00000000..088e33b3 --- /dev/null +++ b/PaddleCV/gan/scripts/infer_stargan.sh @@ -0,0 +1 @@ +python infer.py --model_net StarGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young" diff --git a/PaddleCV/gan/scripts/infer_stgan.sh b/PaddleCV/gan/scripts/infer_stgan.sh new file mode 100644 index 00000000..097e28c8 --- /dev/null +++ b/PaddleCV/gan/scripts/infer_stgan.sh @@ -0,0 +1 @@ +python infer.py --model_net STGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --use_gru True diff --git a/PaddleCV/gan/scripts/run_stargan.sh b/PaddleCV/gan/scripts/run_stargan.sh index 06388968..cce84ab6 100644 --- a/PaddleCV/gan/scripts/run_stargan.sh +++ b/PaddleCV/gan/scripts/run_stargan.sh @@ -1 +1 @@ -python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 200 > log_out 2>log_err +python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err diff --git a/PaddleCV/gan/scripts/run_stgan.sh b/PaddleCV/gan/scripts/run_stgan.sh index 8d20179d..ea8115a2 100644 --- a/PaddleCV/gan/scripts/run_stgan.sh +++ b/PaddleCV/gan/scripts/run_stgan.sh @@ -1 +1 @@ -python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out 2>log_err +python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 20 >log.out 2>log_err diff --git a/PaddleCV/gan/util/utility.py b/PaddleCV/gan/util/utility.py index aacfb1e6..4c6db181 100644 --- a/PaddleCV/gan/util/utility.py +++ b/PaddleCV/gan/util/utility.py @@ -133,6 +133,7 @@ def save_test_image(epoch, elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN': for data in zip(A_test_reader()): real_img, label_org, name = data[0] + attr_names = args.selected_attrs.split(',') label_trg = copy.deepcopy(label_org) tensor_img = fluid.LoDTensor() tensor_label_org = fluid.LoDTensor() @@ -148,6 +149,8 @@ def save_test_image(epoch, for j in range(len(label_org)): label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] + label_trg_tmp = check_attribute_conflict( + label_trg_tmp, attr_names[i], attr_names) label_trg_ = list( map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) @@ -230,3 +233,31 @@ class ImagePool(object): return temp else: return image + + +def check_attribute_conflict(label_batch, attr, attrs): + def _set(label, value, attr): + if attr in attrs: + label[attrs.index(attr)] = value + + attr_id = attrs.index(attr) + for label in label_batch: + if attr in ['Bald', 'Receding_Hairline'] and attrs[attr_id] != 0: + _set(label, 0, 'Bangs') + elif attr == 'Bangs' and attrs[attr_id] != 0: + _set(label, 0, 'Bald') + _set(label, 0, 'Receding_Hairline') + elif attr in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair' + ] and attrs[attr_id] != 0: + for a in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']: + if a != attr: + _set(label, 0, a) + elif attr in ['Straight_Hair', 'Wavy_Hair'] and attrs[attr_id] != 0: + for a in ['Straight_Hair', 'Wavy_Hair']: + if a != attr: + _set(label, 0, a) + elif attr in ['Mustache', 'No_Beard'] and attrs[attr_id] != 0: + for a in ['Mustache', 'No_Beard']: + if a != attr: + _set(label, 0, a) + return label_batch -- GitLab