Unverified commit 2df1c9c3, authored by lvmengsi, committed by GitHub

fix op init (#2491)

fix op init
Parent c271c571
@@ -79,17 +79,28 @@ def norm_layer(input, norm_type='batch_norm', name=None):
 def initial_type(name,
                  input,
+                 op_type,
+                 fan_out,
                  init="normal",
                  use_bias=False,
                  f_in=0,
                  filter_size=0,
                  stddev=0.02):
     if init == "kaiming":
-        fan_in = f_in * filter_size * filter_size
+        if op_type == 'conv':
+            fan_in = input.shape[1] * filter_size * filter_size
+        elif op_type == 'deconv':
+            fan_in = fan_out * filter_size * filter_size
+        else:
+            if len(input.shape) > 2:
+                fan_in = input.shape[1] * input.shape[2] * input.shape[3]
+            else:
+                fan_in = input.shape[1]
+        bound = 1 / math.sqrt(fan_in)
         param_attr = fluid.ParamAttr(
             name=name + "_w",
-            initializer=fluid.initializer.MSRAInitializer(uniform=True))
+            initializer=fluid.initializer.Uniform(
+                low=-bound, high=bound))
     if use_bias == True:
         bias_attr = fluid.ParamAttr(
             name=name + '_b',
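Note on the change above: fan_in is now derived from the actual weight shape for each op type instead of the caller-supplied f_in, and MSRAInitializer is replaced by an explicit uniform range of ±1/sqrt(fan_in), which, as far as I can tell, is the same bound PyTorch uses for its default conv/linear initialization. A minimal standalone sketch of the new rule, with illustrative shapes (kaiming_bound is a hypothetical helper, not part of this patch):

import math

def kaiming_bound(op_type, input_shape, filter_size, fan_out=None):
    # Mirrors the fan_in rules introduced in initial_type():
    #   conv:   fan_in = in_channels * k * k
    #   deconv: fan_in = out_channels * k * k
    #   other:  fan_in = product of the non-batch dims (or dim 1 for 2-D input)
    if op_type == 'conv':
        fan_in = input_shape[1] * filter_size * filter_size
    elif op_type == 'deconv':
        fan_in = fan_out * filter_size * filter_size
    else:
        if len(input_shape) > 2:
            fan_in = input_shape[1] * input_shape[2] * input_shape[3]
        else:
            fan_in = input_shape[1]
    return 1 / math.sqrt(fan_in)

# e.g. a 7x7 conv over a 3-channel input: fan_in = 3 * 49 = 147
print(kaiming_bound('conv', [1, 3, 128, 128], 7))  # ~0.0825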
@@ -131,9 +142,11 @@ def conv2d(input,
     param_attr, bias_attr = initial_type(
         name=name,
         input=input,
+        op_type='conv',
+        fan_out=num_filters,
         init=initial,
         use_bias=use_bias,
         f_in=input.shape[1],
         filter_size=filter_size,
         stddev=stddev)
@@ -210,9 +223,11 @@ def deconv2d(input,
     param_attr, bias_attr = initial_type(
         name=name,
         input=input,
+        op_type='deconv',
+        fan_out=num_filters,
         init=initial,
         use_bias=use_bias,
         f_in=input.shape[1],
         filter_size=filter_size,
         stddev=stddev)
@@ -286,9 +301,11 @@ def linear(input,
     param_attr, bias_attr = initial_type(
         name=name,
         input=input,
+        op_type='linear',
+        fan_out=output_size,
         init=initial,
         use_bias=True,
         f_in=input.shape[1],
         filter_size=1,
         stddev=stddev)
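Taken together, the three updated call sites select the rule by op_type: conv2d derives fan_in from the input channel count, deconv2d from num_filters (Paddle's transposed-conv weight is laid out [in_channels, out_channels, k, k], so its fan_in is out_channels * k * k), and linear falls through to the input-shape branch, passing output_size only as fan_out. A quick sanity check against the sketch above (shapes are illustrative):

import math

# conv2d:   input [N, 64, H, W], 3x3 kernel -> fan_in = 64 * 9 = 576
assert kaiming_bound('conv', [8, 64, 32, 32], 3) == 1 / math.sqrt(576)

# deconv2d: num_filters=32, 4x4 kernel -> fan_in = 32 * 16 = 512
assert kaiming_bound('deconv', [8, 64, 32, 32], 4, fan_out=32) == 1 / math.sqrt(512)

# linear:   2-D input [N, 256] -> fan_in = 256, bound = 1/16
assert kaiming_bound('linear', [8, 256], 1) == 1 / math.sqrt(256)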
-python train.py --model_net AttGAN --dataset celeba --crop_size 170 --load_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out #2>log_err
+python train.py --model_net AttGAN --dataset celeba --crop_size 170 --load_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out 2>log_err
-python train.py --model_net STGAN --dataset celeba --crop_size 170 --load_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out #2>log_err
+python train.py --model_net STGAN --dataset celeba --crop_size 170 --load_size 128 --train_list ./data/celeba/list_attr_celeba.txt --test_list ./data/celeba/test_list_attr_celeba.txt --gan_mode wgan --batch_size 32 --print_freq 1 --num_discriminator_time 5 --epoch 200 >log.out 2>log_err
@@ -67,8 +67,6 @@ class GTrainer():
                 learning_rate=lr, beta1=0.5, beta2=0.999, name="net_G")
             optimizer.minimize(self.g_loss, parameter_list=vars)
-            with open('program_gen.txt', 'w') as f:
-                print(self.program, file=f)

 class DTrainer():
@@ -140,8 +138,6 @@ class DTrainer():
                 learning_rate=lr, beta1=0.5, beta2=0.999, name="net_D")
             optimizer.minimize(self.d_loss, parameter_list=vars)
-            with open('program.txt', 'w') as f:
-                print(self.program, file=f)

     def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
         def _interpolate(a, b=None):
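The two deleted blocks wrote the full program description to program_gen.txt / program.txt on every trainer construction, which was debug-only output. If that dump is still occasionally useful, an opt-in guard along these lines (the DUMP_PROGRAM variable is hypothetical, not part of this patch) would keep it out of normal runs; it belongs at the same spot inside the trainer's setup:

import os

# Hypothetical opt-in debug dump; enable with DUMP_PROGRAM=1 in the environment.
if os.environ.get('DUMP_PROGRAM') == '1':
    with open('program.txt', 'w') as f:
        print(self.program, file=f)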