未验证 提交 f489ca5d 编写于 作者: L lvmengsi 提交者: GitHub

Fix batch norm in infer (#2570)

* fix infer and conflict attribute
上级 bf37c67c
...@@ -27,6 +27,7 @@ import imageio ...@@ -27,6 +27,7 @@ import imageio
import glob import glob
from util.config import add_arguments, print_arguments from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict
import copy import copy
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -86,12 +87,22 @@ def infer(args): ...@@ -86,12 +87,22 @@ def infer(args):
from network.STGAN_network import STGAN_model from network.STGAN_network import STGAN_model
model = STGAN_model() model = STGAN_model()
fake, _ = model.network_G( fake, _ = model.network_G(
input, label_org_, label_trg_, cfg=args, name='net_G') input,
label_org_,
label_trg_,
cfg=args,
name='generator',
is_test=True)
elif args.model_net == 'AttGAN': elif args.model_net == 'AttGAN':
from network.AttGAN_network import AttGAN_model from network.AttGAN_network import AttGAN_model
model = AttGAN_model() model = AttGAN_model()
fake, _ = model.network_G( fake, _ = model.network_G(
input, label_org_, label_trg_, cfg=args, name='net_G') input,
label_org_,
label_trg_,
cfg=args,
name='generator',
is_test=True)
else: else:
raise NotImplementedError("model_net {} is not support".format( raise NotImplementedError("model_net {} is not support".format(
args.model_net)) args.model_net))
...@@ -122,6 +133,7 @@ def infer(args): ...@@ -122,6 +133,7 @@ def infer(args):
args, shuffle=False, return_name=True) args, shuffle=False, return_name=True)
for data in zip(reader_test()): for data in zip(reader_test()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
print("read {}".format(name)) print("read {}".format(name))
label_trg = copy.deepcopy(label_org) label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
...@@ -137,6 +149,8 @@ def infer(args): ...@@ -137,6 +149,8 @@ def infer(args):
label_trg_tmp = copy.deepcopy(label_trg) label_trg_tmp = copy.deepcopy(label_trg)
for j in range(len(label_org)): for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg_ = list( label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)): for j in range(len(label_org)):
......
...@@ -27,17 +27,26 @@ class AttGAN_model(object): ...@@ -27,17 +27,26 @@ class AttGAN_model(object):
def __init__(self): def __init__(self):
pass pass
def network_G(self, input, label_org, label_trg, cfg, name="generator"): def network_G(self,
input,
label_org,
label_trg,
cfg,
name="generator",
is_test=False):
_a = label_org _a = label_org
_b = label_trg _b = label_trg
z = self.Genc( z = self.Genc(
input, input,
name=name + '_Genc', name=name + '_Genc',
dim=cfg.g_base_dims, dim=cfg.g_base_dims,
n_layers=cfg.n_layers) n_layers=cfg.n_layers,
fake_image = self.Gdec(z, _b, name=name + '_Gdec', dim=cfg.g_base_dims) is_test=is_test)
fake_image = self.Gdec(
z, _b, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test)
rec_image = self.Gdec(z, _a, name=name + '_Gdec', dim=cfg.g_base_dims) rec_image = self.Gdec(
z, _a, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test)
return fake_image, rec_image return fake_image, rec_image
def network_D(self, input, cfg, name="discriminator"): def network_D(self, input, cfg, name="discriminator"):
...@@ -54,7 +63,7 @@ class AttGAN_model(object): ...@@ -54,7 +63,7 @@ class AttGAN_model(object):
z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0) z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
return fluid.layers.concat([z, ones * a], axis=1) return fluid.layers.concat([z, ones * a], axis=1)
def Genc(self, input, dim=64, n_layers=5, name='G_enc_'): def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
z = input z = input
zs = [] zs = []
for i in range(n_layers): for i in range(n_layers):
...@@ -71,7 +80,8 @@ class AttGAN_model(object): ...@@ -71,7 +80,8 @@ class AttGAN_model(object):
name=name + str(i), name=name + str(i),
use_bias=False, use_bias=False,
relufactor=0.01, relufactor=0.01,
initial='kaiming') initial='kaiming',
is_test=is_test)
zs.append(z) zs.append(z)
return zs return zs
...@@ -83,7 +93,8 @@ class AttGAN_model(object): ...@@ -83,7 +93,8 @@ class AttGAN_model(object):
n_layers=5, n_layers=5,
shortcut_layers=1, shortcut_layers=1,
inject_layers=1, inject_layers=1,
name='G_dec_'): name='G_dec_',
is_test=False):
shortcut_layers = min(shortcut_layers, n_layers - 1) shortcut_layers = min(shortcut_layers, n_layers - 1)
inject_layers = min(inject_layers, n_layers - 1) inject_layers = min(inject_layers, n_layers - 1)
...@@ -101,7 +112,8 @@ class AttGAN_model(object): ...@@ -101,7 +112,8 @@ class AttGAN_model(object):
norm='batch_norm', norm='batch_norm',
activation_fn='relu', activation_fn='relu',
use_bias=False, use_bias=False,
initial='kaiming') initial='kaiming',
is_test=is_test)
if shortcut_layers > i: if shortcut_layers > i:
z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1) z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
if inject_layers > i: if inject_layers > i:
...@@ -116,7 +128,8 @@ class AttGAN_model(object): ...@@ -116,7 +128,8 @@ class AttGAN_model(object):
name=name + str(i), name=name + str(i),
activation_fn='tanh', activation_fn='tanh',
use_bias=True, use_bias=True,
initial='kaiming') initial='kaiming',
is_test=is_test)
return x return x
def D(self, def D(self,
......
...@@ -27,37 +27,48 @@ class STGAN_model(object): ...@@ -27,37 +27,48 @@ class STGAN_model(object):
def __init__(self): def __init__(self):
pass pass
def network_G(self, input, label_org, label_trg, cfg, name="generator"): def network_G(self,
input,
label_org,
label_trg,
cfg,
name="generator",
is_test=False):
_a = label_org _a = label_org
_b = label_trg _b = label_trg
z = self.Genc( z = self.Genc(
input, input,
name=name + '_Genc', name=name + '_Genc',
n_layers=cfg.n_layers, n_layers=cfg.n_layers,
dim=cfg.g_base_dims) dim=cfg.g_base_dims,
is_test=is_test)
zb = self.GRU(z, zb = self.GRU(z,
fluid.layers.elementwise_sub(_b, _a), fluid.layers.elementwise_sub(_b, _a),
name=name + '_GRU', name=name + '_GRU',
dim=cfg.g_base_dims, dim=cfg.g_base_dims,
n_layers=cfg.gru_n_layers) if cfg.use_gru else z n_layers=cfg.gru_n_layers,
is_test=is_test) if cfg.use_gru else z
fake_image = self.Gdec( fake_image = self.Gdec(
zb, zb,
fluid.layers.elementwise_sub(_b, _a), fluid.layers.elementwise_sub(_b, _a),
name=name + '_Gdec', name=name + '_Gdec',
dim=cfg.g_base_dims, dim=cfg.g_base_dims,
n_layers=cfg.n_layers) n_layers=cfg.n_layers,
is_test=is_test)
za = self.GRU(z, za = self.GRU(z,
fluid.layers.elementwise_sub(_a, _a), fluid.layers.elementwise_sub(_a, _a),
name=name + '_GRU', name=name + '_GRU',
dim=cfg.g_base_dims, dim=cfg.g_base_dims,
n_layers=cfg.gru_n_layers) if cfg.use_gru else z n_layers=cfg.gru_n_layers,
is_test=is_test) if cfg.use_gru else z
rec_image = self.Gdec( rec_image = self.Gdec(
za, za,
fluid.layers.elementwise_sub(_a, _a), fluid.layers.elementwise_sub(_a, _a),
name=name + '_Gdec', name=name + '_Gdec',
dim=cfg.g_base_dims, dim=cfg.g_base_dims,
n_layers=cfg.n_layers) n_layers=cfg.n_layers,
is_test=is_test)
return fake_image, rec_image return fake_image, rec_image
def network_D(self, input, cfg, name="discriminator"): def network_D(self, input, cfg, name="discriminator"):
...@@ -74,7 +85,7 @@ class STGAN_model(object): ...@@ -74,7 +85,7 @@ class STGAN_model(object):
z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0) z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
return fluid.layers.concat([z, ones * a], axis=1) return fluid.layers.concat([z, ones * a], axis=1)
def Genc(self, input, dim=64, n_layers=5, name='G_enc_'): def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
z = input z = input
zs = [] zs = []
for i in range(n_layers): for i in range(n_layers):
...@@ -90,7 +101,8 @@ class STGAN_model(object): ...@@ -90,7 +101,8 @@ class STGAN_model(object):
name=name + str(i), name=name + str(i),
use_bias=False, use_bias=False,
relufactor=0.01, relufactor=0.01,
initial='kaiming') initial='kaiming',
is_test=is_test)
zs.append(z) zs.append(z)
return zs return zs
...@@ -104,7 +116,8 @@ class STGAN_model(object): ...@@ -104,7 +116,8 @@ class STGAN_model(object):
kernel_size=3, kernel_size=3,
norm=None, norm=None,
pass_state='lstate', pass_state='lstate',
name='G_gru_'): name='G_gru_',
is_test=False):
zs_ = [zs[-1]] zs_ = [zs[-1]]
state = self.concat(zs[-1], a) state = self.concat(zs[-1], a)
...@@ -117,7 +130,8 @@ class STGAN_model(object): ...@@ -117,7 +130,8 @@ class STGAN_model(object):
kernel_size=kernel_size, kernel_size=kernel_size,
norm=norm, norm=norm,
pass_state=pass_state, pass_state=pass_state,
name=name + str(i)) name=name + str(i),
is_test=is_test)
zs_.insert(0, output[0] + zs[n_layers - 1 - i]) zs_.insert(0, output[0] + zs[n_layers - 1 - i])
if inject_layers > i: if inject_layers > i:
state = self.concat(output[1], a) state = self.concat(output[1], a)
...@@ -132,7 +146,8 @@ class STGAN_model(object): ...@@ -132,7 +146,8 @@ class STGAN_model(object):
n_layers=5, n_layers=5,
shortcut_layers=4, shortcut_layers=4,
inject_layers=4, inject_layers=4,
name='G_dec_'): name='G_dec_',
is_test=False):
shortcut_layers = min(shortcut_layers, n_layers - 1) shortcut_layers = min(shortcut_layers, n_layers - 1)
inject_layers = min(inject_layers, n_layers - 1) inject_layers = min(inject_layers, n_layers - 1)
...@@ -150,7 +165,8 @@ class STGAN_model(object): ...@@ -150,7 +165,8 @@ class STGAN_model(object):
norm='batch_norm', norm='batch_norm',
activation_fn='relu', activation_fn='relu',
use_bias=False, use_bias=False,
initial='kaiming') initial='kaiming',
is_test=is_test)
if shortcut_layers > i: if shortcut_layers > i:
z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1) z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
if inject_layers > i: if inject_layers > i:
...@@ -165,7 +181,8 @@ class STGAN_model(object): ...@@ -165,7 +181,8 @@ class STGAN_model(object):
name=name + str(i), name=name + str(i),
activation_fn='tanh', activation_fn='tanh',
use_bias=True, use_bias=True,
initial='kaiming') initial='kaiming',
is_test=is_test)
return x return x
def D(self, def D(self,
...@@ -220,7 +237,8 @@ class STGAN_model(object): ...@@ -220,7 +237,8 @@ class STGAN_model(object):
kernel_size=3, kernel_size=3,
norm=None, norm=None,
pass_state='lstate', pass_state='lstate',
name='gru'): name='gru',
is_test=False):
state_ = deconv2d( state_ = deconv2d(
state, state,
out_channel, out_channel,
...@@ -229,7 +247,8 @@ class STGAN_model(object): ...@@ -229,7 +247,8 @@ class STGAN_model(object):
padding_type='SAME', padding_type='SAME',
name=name + '_deconv2d', name=name + '_deconv2d',
use_bias=True, use_bias=True,
initial='kaiming' initial='kaiming',
is_test=is_test,
) # upsample and make `channel` identical to `out_channel` ) # upsample and make `channel` identical to `out_channel`
reset_gate = conv2d( reset_gate = conv2d(
fluid.layers.concat( fluid.layers.concat(
...@@ -241,7 +260,8 @@ class STGAN_model(object): ...@@ -241,7 +260,8 @@ class STGAN_model(object):
padding_type='SAME', padding_type='SAME',
use_bias=True, use_bias=True,
name=name + '_reset_gate', name=name + '_reset_gate',
initial='kaiming') initial='kaiming',
is_test=is_test)
update_gate = conv2d( update_gate = conv2d(
fluid.layers.concat( fluid.layers.concat(
[in_data, state_], axis=1), [in_data, state_], axis=1),
...@@ -252,7 +272,8 @@ class STGAN_model(object): ...@@ -252,7 +272,8 @@ class STGAN_model(object):
padding_type='SAME', padding_type='SAME',
use_bias=True, use_bias=True,
name=name + '_update_gate', name=name + '_update_gate',
initial='kaiming') initial='kaiming',
is_test=is_test)
left_state = reset_gate * state_ left_state = reset_gate * state_
new_info = conv2d( new_info = conv2d(
fluid.layers.concat( fluid.layers.concat(
...@@ -264,7 +285,8 @@ class STGAN_model(object): ...@@ -264,7 +285,8 @@ class STGAN_model(object):
name=name + '_info', name=name + '_info',
padding_type='SAME', padding_type='SAME',
use_bias=True, use_bias=True,
initial='kaiming') initial='kaiming',
is_test=is_test)
output = (1 - update_gate) * state_ + update_gate * new_info output = (1 - update_gate) * state_ + update_gate * new_info
if pass_state == 'output': if pass_state == 'output':
return output, output return output, output
......
...@@ -34,7 +34,7 @@ def cal_padding(img_size, stride, filter_size, dilation=1): ...@@ -34,7 +34,7 @@ def cal_padding(img_size, stride, filter_size, dilation=1):
return out_size // 2, out_size - out_size // 2 return out_size // 2, out_size - out_size // 2
def norm_layer(input, norm_type='batch_norm', name=None): def norm_layer(input, norm_type='batch_norm', name=None, is_test=False):
if norm_type == 'batch_norm': if norm_type == 'batch_norm':
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
name=name + '_w', initializer=fluid.initializer.Constant(1.0)) name=name + '_w', initializer=fluid.initializer.Constant(1.0))
...@@ -44,6 +44,7 @@ def norm_layer(input, norm_type='batch_norm', name=None): ...@@ -44,6 +44,7 @@ def norm_layer(input, norm_type='batch_norm', name=None):
input, input,
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr, bias_attr=bias_attr,
is_test=is_test,
moving_mean_name=name + '_mean', moving_mean_name=name + '_mean',
moving_variance_name=name + '_var') moving_variance_name=name + '_var')
...@@ -133,7 +134,8 @@ def conv2d(input, ...@@ -133,7 +134,8 @@ def conv2d(input,
relufactor=0.0, relufactor=0.0,
use_bias=False, use_bias=False,
padding_type=None, padding_type=None,
initial="normal"): initial="normal",
is_test=False):
if padding != 0 and padding_type != None: if padding != 0 and padding_type != None:
warnings.warn( warnings.warn(
...@@ -181,7 +183,8 @@ def conv2d(input, ...@@ -181,7 +183,8 @@ def conv2d(input,
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr)
if norm is not None: if norm is not None:
conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm") conv = norm_layer(
input=conv, norm_type=norm, name=name + "_norm", is_test=is_test)
if activation_fn == 'relu': if activation_fn == 'relu':
conv = fluid.layers.relu(conv, name=name + '_relu') conv = fluid.layers.relu(conv, name=name + '_relu')
elif activation_fn == 'leaky_relu': elif activation_fn == 'leaky_relu':
...@@ -214,7 +217,8 @@ def deconv2d(input, ...@@ -214,7 +217,8 @@ def deconv2d(input,
use_bias=False, use_bias=False,
padding_type=None, padding_type=None,
output_size=None, output_size=None,
initial="normal"): initial="normal",
is_test=False):
if padding != 0 and padding_type != None: if padding != 0 and padding_type != None:
warnings.warn( warnings.warn(
...@@ -268,7 +272,8 @@ def deconv2d(input, ...@@ -268,7 +272,8 @@ def deconv2d(input,
conv, paddings=outpadding, mode='constant', pad_value=0.0) conv, paddings=outpadding, mode='constant', pad_value=0.0)
if norm is not None: if norm is not None:
conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm") conv = norm_layer(
input=conv, norm_type=norm, name=name + "_norm", is_test=is_test)
if activation_fn == 'relu': if activation_fn == 'relu':
conv = fluid.layers.relu(conv, name=name + '_relu') conv = fluid.layers.relu(conv, name=name + '_relu')
elif activation_fn == 'leaky_relu': elif activation_fn == 'leaky_relu':
...@@ -297,7 +302,8 @@ def linear(input, ...@@ -297,7 +302,8 @@ def linear(input,
activation_fn=None, activation_fn=None,
relufactor=0.2, relufactor=0.2,
name="linear", name="linear",
initial="normal"): initial="normal",
is_test=False):
param_attr, bias_attr = initial_type( param_attr, bias_attr = initial_type(
name=name, name=name,
...@@ -316,7 +322,8 @@ def linear(input, ...@@ -316,7 +322,8 @@ def linear(input,
name=name) name=name)
if norm is not None: if norm is not None:
linear = norm_layer(input=linear, norm_type=norm, name=name + '_norm') linear = norm_layer(
input=linear, norm_type=norm, name=name + '_norm', is_test=is_test)
if activation_fn == 'relu': if activation_fn == 'relu':
linear = fluid.layers.relu(linear, name=name + '_relu') linear = fluid.layers.relu(linear, name=name + '_relu')
elif activation_fn == 'leaky_relu': elif activation_fn == 'leaky_relu':
......
python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128
python infer.py --model_net StarGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young"
python infer.py --model_net STGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --use_gru True
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 200 > log_out 2>log_err python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out 2>log_err python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 20 >log.out 2>log_err
...@@ -133,6 +133,7 @@ def save_test_image(epoch, ...@@ -133,6 +133,7 @@ def save_test_image(epoch,
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN': elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
for data in zip(A_test_reader()): for data in zip(A_test_reader()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
label_trg = copy.deepcopy(label_org) label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor() tensor_label_org = fluid.LoDTensor()
...@@ -148,6 +149,8 @@ def save_test_image(epoch, ...@@ -148,6 +149,8 @@ def save_test_image(epoch,
for j in range(len(label_org)): for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg_ = list( label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp)) map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
...@@ -230,3 +233,31 @@ class ImagePool(object): ...@@ -230,3 +233,31 @@ class ImagePool(object):
return temp return temp
else: else:
return image return image
def check_attribute_conflict(label_batch, attr, attrs):
    """Resolve mutually-exclusive CelebA attribute conflicts after flipping `attr`.

    When the target attribute is turned ON for a sample, attributes that cannot
    co-exist with it (e.g. 'Bald' vs 'Bangs', the four hair colors, straight vs
    wavy hair, mustache vs no-beard) are forced to 0 so the generator is never
    asked for an impossible combination.

    Args:
        label_batch: list of per-sample label lists (mutated in place); each
            label is indexed parallel to `attrs` with 0/1 (or float) values.
        attr: name of the attribute that was just flipped.
        attrs: ordered list of attribute names defining the label layout.

    Returns:
        The same `label_batch` object, with conflicting entries zeroed.
    """

    def _set(label, value, attr):
        # Silently skip attributes not present in this label layout.
        if attr in attrs:
            label[attrs.index(attr)] = value

    attr_id = attrs.index(attr)
    for label in label_batch:
        # Bug fix: the original tested `attrs[attr_id] != 0`, i.e. the attribute
        # *name* string, which is always truthy — so conflicts were cleared even
        # when `attr` was being switched OFF. Test the sample's label value
        # instead: only resolve conflicts when `attr` is actually set.
        if attr in ['Bald', 'Receding_Hairline'] and label[attr_id] != 0:
            _set(label, 0, 'Bangs')
        elif attr == 'Bangs' and label[attr_id] != 0:
            _set(label, 0, 'Bald')
            _set(label, 0, 'Receding_Hairline')
        elif attr in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
                      ] and label[attr_id] != 0:
            # Hair colors are mutually exclusive: keep only the chosen one.
            for a in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                if a != attr:
                    _set(label, 0, a)
        elif attr in ['Straight_Hair', 'Wavy_Hair'] and label[attr_id] != 0:
            for a in ['Straight_Hair', 'Wavy_Hair']:
                if a != attr:
                    _set(label, 0, a)
        elif attr in ['Mustache', 'No_Beard'] and label[attr_id] != 0:
            for a in ['Mustache', 'No_Beard']:
                if a != attr:
                    _set(label, 0, a)
    return label_batch
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册