未验证 提交 f489ca5d 编写于 作者: L lvmengsi 提交者: GitHub

Fix batch norm in infer (#2570)

* fix infer and conflict attribute
上级 bf37c67c
......@@ -27,6 +27,7 @@ import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict
import copy
parser = argparse.ArgumentParser(description=__doc__)
......@@ -86,12 +87,22 @@ def infer(args):
from network.STGAN_network import STGAN_model
model = STGAN_model()
fake, _ = model.network_G(
input, label_org_, label_trg_, cfg=args, name='net_G')
input,
label_org_,
label_trg_,
cfg=args,
name='generator',
is_test=True)
elif args.model_net == 'AttGAN':
from network.AttGAN_network import AttGAN_model
model = AttGAN_model()
fake, _ = model.network_G(
input, label_org_, label_trg_, cfg=args, name='net_G')
input,
label_org_,
label_trg_,
cfg=args,
name='generator',
is_test=True)
else:
raise NotImplementedError("model_net {} is not support".format(
args.model_net))
......@@ -122,6 +133,7 @@ def infer(args):
args, shuffle=False, return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
print("read {}".format(name))
label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor()
......@@ -137,6 +149,8 @@ def infer(args):
label_trg_tmp = copy.deepcopy(label_trg)
for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)):
......
......@@ -27,17 +27,26 @@ class AttGAN_model(object):
def __init__(self):
pass
def network_G(self, input, label_org, label_trg, cfg, name="generator"):
def network_G(self,
input,
label_org,
label_trg,
cfg,
name="generator",
is_test=False):
_a = label_org
_b = label_trg
z = self.Genc(
input,
name=name + '_Genc',
dim=cfg.g_base_dims,
n_layers=cfg.n_layers)
fake_image = self.Gdec(z, _b, name=name + '_Gdec', dim=cfg.g_base_dims)
n_layers=cfg.n_layers,
is_test=is_test)
fake_image = self.Gdec(
z, _b, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test)
rec_image = self.Gdec(z, _a, name=name + '_Gdec', dim=cfg.g_base_dims)
rec_image = self.Gdec(
z, _a, name=name + '_Gdec', dim=cfg.g_base_dims, is_test=is_test)
return fake_image, rec_image
def network_D(self, input, cfg, name="discriminator"):
......@@ -54,7 +63,7 @@ class AttGAN_model(object):
z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
return fluid.layers.concat([z, ones * a], axis=1)
def Genc(self, input, dim=64, n_layers=5, name='G_enc_'):
def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
z = input
zs = []
for i in range(n_layers):
......@@ -71,7 +80,8 @@ class AttGAN_model(object):
name=name + str(i),
use_bias=False,
relufactor=0.01,
initial='kaiming')
initial='kaiming',
is_test=is_test)
zs.append(z)
return zs
......@@ -83,7 +93,8 @@ class AttGAN_model(object):
n_layers=5,
shortcut_layers=1,
inject_layers=1,
name='G_dec_'):
name='G_dec_',
is_test=False):
shortcut_layers = min(shortcut_layers, n_layers - 1)
inject_layers = min(inject_layers, n_layers - 1)
......@@ -101,7 +112,8 @@ class AttGAN_model(object):
norm='batch_norm',
activation_fn='relu',
use_bias=False,
initial='kaiming')
initial='kaiming',
is_test=is_test)
if shortcut_layers > i:
z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
if inject_layers > i:
......@@ -116,7 +128,8 @@ class AttGAN_model(object):
name=name + str(i),
activation_fn='tanh',
use_bias=True,
initial='kaiming')
initial='kaiming',
is_test=is_test)
return x
def D(self,
......
......@@ -27,37 +27,48 @@ class STGAN_model(object):
def __init__(self):
pass
def network_G(self, input, label_org, label_trg, cfg, name="generator"):
def network_G(self,
input,
label_org,
label_trg,
cfg,
name="generator",
is_test=False):
_a = label_org
_b = label_trg
z = self.Genc(
input,
name=name + '_Genc',
n_layers=cfg.n_layers,
dim=cfg.g_base_dims)
dim=cfg.g_base_dims,
is_test=is_test)
zb = self.GRU(z,
fluid.layers.elementwise_sub(_b, _a),
name=name + '_GRU',
dim=cfg.g_base_dims,
n_layers=cfg.gru_n_layers) if cfg.use_gru else z
n_layers=cfg.gru_n_layers,
is_test=is_test) if cfg.use_gru else z
fake_image = self.Gdec(
zb,
fluid.layers.elementwise_sub(_b, _a),
name=name + '_Gdec',
dim=cfg.g_base_dims,
n_layers=cfg.n_layers)
n_layers=cfg.n_layers,
is_test=is_test)
za = self.GRU(z,
fluid.layers.elementwise_sub(_a, _a),
name=name + '_GRU',
dim=cfg.g_base_dims,
n_layers=cfg.gru_n_layers) if cfg.use_gru else z
n_layers=cfg.gru_n_layers,
is_test=is_test) if cfg.use_gru else z
rec_image = self.Gdec(
za,
fluid.layers.elementwise_sub(_a, _a),
name=name + '_Gdec',
dim=cfg.g_base_dims,
n_layers=cfg.n_layers)
n_layers=cfg.n_layers,
is_test=is_test)
return fake_image, rec_image
def network_D(self, input, cfg, name="discriminator"):
......@@ -74,7 +85,7 @@ class STGAN_model(object):
z, [-1, a.shape[1], z.shape[2], z.shape[3]], "float32", 1.0)
return fluid.layers.concat([z, ones * a], axis=1)
def Genc(self, input, dim=64, n_layers=5, name='G_enc_'):
def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
z = input
zs = []
for i in range(n_layers):
......@@ -90,7 +101,8 @@ class STGAN_model(object):
name=name + str(i),
use_bias=False,
relufactor=0.01,
initial='kaiming')
initial='kaiming',
is_test=is_test)
zs.append(z)
return zs
......@@ -104,7 +116,8 @@ class STGAN_model(object):
kernel_size=3,
norm=None,
pass_state='lstate',
name='G_gru_'):
name='G_gru_',
is_test=False):
zs_ = [zs[-1]]
state = self.concat(zs[-1], a)
......@@ -117,7 +130,8 @@ class STGAN_model(object):
kernel_size=kernel_size,
norm=norm,
pass_state=pass_state,
name=name + str(i))
name=name + str(i),
is_test=is_test)
zs_.insert(0, output[0] + zs[n_layers - 1 - i])
if inject_layers > i:
state = self.concat(output[1], a)
......@@ -132,7 +146,8 @@ class STGAN_model(object):
n_layers=5,
shortcut_layers=4,
inject_layers=4,
name='G_dec_'):
name='G_dec_',
is_test=False):
shortcut_layers = min(shortcut_layers, n_layers - 1)
inject_layers = min(inject_layers, n_layers - 1)
......@@ -150,7 +165,8 @@ class STGAN_model(object):
norm='batch_norm',
activation_fn='relu',
use_bias=False,
initial='kaiming')
initial='kaiming',
is_test=is_test)
if shortcut_layers > i:
z = fluid.layers.concat([z, zs[n_layers - 2 - i]], axis=1)
if inject_layers > i:
......@@ -165,7 +181,8 @@ class STGAN_model(object):
name=name + str(i),
activation_fn='tanh',
use_bias=True,
initial='kaiming')
initial='kaiming',
is_test=is_test)
return x
def D(self,
......@@ -220,7 +237,8 @@ class STGAN_model(object):
kernel_size=3,
norm=None,
pass_state='lstate',
name='gru'):
name='gru',
is_test=False):
state_ = deconv2d(
state,
out_channel,
......@@ -229,7 +247,8 @@ class STGAN_model(object):
padding_type='SAME',
name=name + '_deconv2d',
use_bias=True,
initial='kaiming'
initial='kaiming',
is_test=is_test,
) # upsample and make `channel` identical to `out_channel`
reset_gate = conv2d(
fluid.layers.concat(
......@@ -241,7 +260,8 @@ class STGAN_model(object):
padding_type='SAME',
use_bias=True,
name=name + '_reset_gate',
initial='kaiming')
initial='kaiming',
is_test=is_test)
update_gate = conv2d(
fluid.layers.concat(
[in_data, state_], axis=1),
......@@ -252,7 +272,8 @@ class STGAN_model(object):
padding_type='SAME',
use_bias=True,
name=name + '_update_gate',
initial='kaiming')
initial='kaiming',
is_test=is_test)
left_state = reset_gate * state_
new_info = conv2d(
fluid.layers.concat(
......@@ -264,7 +285,8 @@ class STGAN_model(object):
name=name + '_info',
padding_type='SAME',
use_bias=True,
initial='kaiming')
initial='kaiming',
is_test=is_test)
output = (1 - update_gate) * state_ + update_gate * new_info
if pass_state == 'output':
return output, output
......
......@@ -34,7 +34,7 @@ def cal_padding(img_size, stride, filter_size, dilation=1):
return out_size // 2, out_size - out_size // 2
def norm_layer(input, norm_type='batch_norm', name=None):
def norm_layer(input, norm_type='batch_norm', name=None, is_test=False):
if norm_type == 'batch_norm':
param_attr = fluid.ParamAttr(
name=name + '_w', initializer=fluid.initializer.Constant(1.0))
......@@ -44,6 +44,7 @@ def norm_layer(input, norm_type='batch_norm', name=None):
input,
param_attr=param_attr,
bias_attr=bias_attr,
is_test=is_test,
moving_mean_name=name + '_mean',
moving_variance_name=name + '_var')
......@@ -133,7 +134,8 @@ def conv2d(input,
relufactor=0.0,
use_bias=False,
padding_type=None,
initial="normal"):
initial="normal",
is_test=False):
if padding != 0 and padding_type != None:
warnings.warn(
......@@ -181,7 +183,8 @@ def conv2d(input,
param_attr=param_attr,
bias_attr=bias_attr)
if norm is not None:
conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm")
conv = norm_layer(
input=conv, norm_type=norm, name=name + "_norm", is_test=is_test)
if activation_fn == 'relu':
conv = fluid.layers.relu(conv, name=name + '_relu')
elif activation_fn == 'leaky_relu':
......@@ -214,7 +217,8 @@ def deconv2d(input,
use_bias=False,
padding_type=None,
output_size=None,
initial="normal"):
initial="normal",
is_test=False):
if padding != 0 and padding_type != None:
warnings.warn(
......@@ -268,7 +272,8 @@ def deconv2d(input,
conv, paddings=outpadding, mode='constant', pad_value=0.0)
if norm is not None:
conv = norm_layer(input=conv, norm_type=norm, name=name + "_norm")
conv = norm_layer(
input=conv, norm_type=norm, name=name + "_norm", is_test=is_test)
if activation_fn == 'relu':
conv = fluid.layers.relu(conv, name=name + '_relu')
elif activation_fn == 'leaky_relu':
......@@ -297,7 +302,8 @@ def linear(input,
activation_fn=None,
relufactor=0.2,
name="linear",
initial="normal"):
initial="normal",
is_test=False):
param_attr, bias_attr = initial_type(
name=name,
......@@ -316,7 +322,8 @@ def linear(input,
name=name)
if norm is not None:
linear = norm_layer(input=linear, norm_type=norm, name=name + '_norm')
linear = norm_layer(
input=linear, norm_type=norm, name=name + '_norm', is_test=is_test)
if activation_fn == 'relu':
linear = fluid.layers.relu(linear, name=name + '_relu')
elif activation_fn == 'leaky_relu':
......
python infer.py --model_net AttGAN --init_model output/checkpoints/199/ --dataset_dir "data/celeba/" --image_size 128
python infer.py --model_net StarGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young"
python infer.py --model_net STGAN --init_model output/checkpoints/19/ --dataset_dir "data/celeba/" --image_size 128 --use_gru True
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 200 > log_out 2>log_err
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 16 --epoch 20 > log_out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out 2>log_err
python train.py --model_net STGAN --dataset celeba --crop_size 170 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 20 >log.out 2>log_err
......@@ -133,6 +133,7 @@ def save_test_image(epoch,
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
for data in zip(A_test_reader()):
real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
......@@ -148,6 +149,8 @@ def save_test_image(epoch,
for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg_ = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
......@@ -230,3 +233,31 @@ class ImagePool(object):
return temp
else:
return image
def check_attribute_conflict(label_batch, attr, attrs):
    """Clear attributes that contradict ``attr`` in every label of a batch.

    For each label in which ``attr`` is switched on (non-zero), the
    attributes that cannot coexist with it are reset to 0:

    * ``Bald`` / ``Receding_Hairline`` clear ``Bangs``;
    * ``Bangs`` clears ``Bald`` and ``Receding_Hairline``;
    * hair colours, hair styles and ``Mustache`` / ``No_Beard`` are
      mutually exclusive within their own group.

    Labels in which ``attr`` is 0 are left untouched, as are conflicting
    attributes that are not present in ``attrs``.

    Args:
        label_batch: iterable of mutable, index-addressable label vectors,
            one entry per name in ``attrs``. Modified in place.
        attr: name of the attribute that has just been edited.
        attrs: sequence of attribute names giving the column order of the
            label vectors.

    Returns:
        The same ``label_batch`` object, after in-place modification.

    Raises:
        ValueError: if ``attr`` is not contained in ``attrs``.
    """
    exclusive_groups = (
        ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'),
        ('Straight_Hair', 'Wavy_Hair'),
        ('Mustache', 'No_Beard'),
    )

    attr_id = attrs.index(attr)

    # Work out once which attribute names conflict with the edited one.
    if attr in ('Bald', 'Receding_Hairline'):
        conflicts = ['Bangs']
    elif attr == 'Bangs':
        conflicts = ['Bald', 'Receding_Hairline']
    else:
        conflicts = []
        for group in exclusive_groups:
            if attr in group:
                conflicts = [other for other in group if other != attr]
                break

    # Only attributes that are actually selected have a column to clear.
    conflict_ids = [attrs.index(name) for name in conflicts if name in attrs]

    for label in label_batch:
        # The check must look at this sample's value of the edited
        # attribute (not at the attribute-name list, which is never 0):
        # a conflict exists only when the attribute is switched on.
        if label[attr_id] != 0:
            for idx in conflict_ids:
                label[idx] = 0
    return label_batch
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册