未验证 提交 b0d375a8 编写于 作者: C ceci3 提交者: GitHub

Update GAN to 1.8 (#4599)

* fix init

* update

* update_1.8

* update paddle.reader
上级 4d7ec517
......@@ -79,10 +79,11 @@ def train(args):
g_program_test = dg_program.clone(for_test=True)
dg_logit = D_cond(g_img, conditions)
dg_logit_shape = fluid.layers.shape(dg_logit)
dg_loss = loss(
dg_logit,
fluid.layers.fill_constant_batch_size_like(
input=noise, dtype='float32', shape=[-1, 1], value=1.0))
fluid.layers.fill_constant(
dtype='float32', shape=[dg_logit_shape[0], 1], value=1.0))
opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
......@@ -96,11 +97,11 @@ def train(args):
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
if args.run_ce:
train_reader = paddle.batch(
train_reader = fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.reader.shuffle(
train_reader = fluid.io.batch(
fluid.io.shuffle(
paddle.dataset.mnist.train(), buf_size=60000),
batch_size=args.batch_size)
......
......@@ -74,10 +74,11 @@ def train(args):
g_program_test = dg_program.clone(for_test=True)
dg_logit = D(g_img)
dg_logit_shape = fluid.layers.shape(dg_logit)
dg_loss = loss(
dg_logit,
fluid.layers.fill_constant_batch_size_like(
input=noise, dtype='float32', shape=[-1, 1], value=1.0))
fluid.layers.fill_constant(
shape=[dg_logit_shape[0], 1], dtype='float32', value=1.0))
opt = fluid.optimizer.Adam(learning_rate=LEARNING_RATE)
......@@ -92,11 +93,11 @@ def train(args):
exe.run(fluid.default_startup_program())
if args.run_ce:
train_reader = paddle.batch(
train_reader = fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=args.batch_size)
else:
train_reader = paddle.batch(
paddle.reader.shuffle(
train_reader = fluid.io.batch(
fluid.io.shuffle(
paddle.dataset.mnist.train(), buf_size=60000),
batch_size=args.batch_size)
......
......@@ -34,6 +34,7 @@ use_cudnn = True
if 'ce_mode' in os.environ:
use_cudnn = False
def bn(x, name=None, act='relu'):
if name is None:
name = get_parent_function_name()
......@@ -100,8 +101,11 @@ def deconv(x,
def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis.

    A tensor of ones with shape
    ``[batch, y.shape[1], x.shape[2], x.shape[3]]`` is created, multiplied by
    ``y`` and concatenated to ``x`` along axis 1.

    Args:
        x: Feature map; indexed up to ``x.shape[3]``, so it must have at
            least four dimensions.
        y: Conditioning tensor; ``y.shape[1]`` supplies the number of
            channels appended to ``x``.
            NOTE(review): presumably shaped so that ``ones * y`` broadcasts
            over the spatial dimensions -- confirm against callers.

    Returns:
        The concatenation of ``x`` and ``ones * y`` along axis 1.
    """
    # The batch size is read from the run-time shape of ``x`` so the constant
    # follows whatever batch is fed, replacing the former
    # ``fill_constant_batch_size_like`` call.
    x_shape = fluid.layers.shape(x)
    ones = fluid.layers.fill_constant(
        shape=[x_shape[0], y.shape[1], x.shape[2], x.shape[3]],
        dtype='float32',
        value=1.0)
    return fluid.layers.concat([x, ones * y], 1)
......
......@@ -22,6 +22,7 @@ import argparse
import struct
import os
import paddle
import paddle.fluid as fluid
import random
import sys
......@@ -520,8 +521,8 @@ class data_reader(object):
train_labels = os.path.join(self.cfg.data_dir, self.cfg.dataset,
"train-labels-idx1-ubyte.gz")
train_reader = paddle.batch(
paddle.reader.shuffle(
train_reader = fluid.io.batch(
fluid.io.shuffle(
mnist_reader_creator(train_images, train_labels, 100),
buf_size=60000),
batch_size=self.cfg.batch_size)
......
......@@ -60,9 +60,14 @@ class AttGAN_model(object):
def concat(self, z, a):
    """Concatenate attribute vector on feature map axis.

    A tensor of ones with shape
    ``[batch, a.shape[1], z.shape[2], z.shape[3]]`` is multiplied elementwise
    by ``a`` (aligned from axis 0) and the product is concatenated to ``z``
    along axis 1.

    Args:
        z: Feature map; indexed up to ``z.shape[3]``, so it must have at
            least four dimensions.
        a: Attribute tensor; ``a.shape[1]`` supplies the number of channels
            appended to ``z``.

    Returns:
        The concatenation of ``z`` and the expanded attributes along axis 1.
    """
    # The batch size comes from the run-time shape of ``z``; the earlier
    # ``fill_constant_batch_size_like`` path and its early return are gone,
    # so this code is now reachable.
    batch = fluid.layers.shape(z)[0]
    ones = fluid.layers.fill_constant(
        shape=[batch, a.shape[1], z.shape[2], z.shape[3]],
        dtype="float32",
        value=1.0)
    return fluid.layers.concat(
        [z, fluid.layers.elementwise_mul(
            ones, a, axis=0)], axis=1)
def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
z = input
......
......@@ -82,9 +82,14 @@ class STGAN_model(object):
def concat(self, z, a):
    """Concatenate attribute vector on feature map axis.

    A tensor of ones with shape
    ``[batch, a.shape[1], z.shape[2], z.shape[3]]`` is multiplied elementwise
    by ``a`` (aligned from axis 0) and the product is concatenated to ``z``
    along axis 1.

    Args:
        z: Feature map; indexed up to ``z.shape[3]``, so it must have at
            least four dimensions.
        a: Attribute tensor; ``a.shape[1]`` supplies the number of channels
            appended to ``z``.

    Returns:
        The concatenation of ``z`` and the expanded attributes along axis 1.
    """
    # The batch size comes from the run-time shape of ``z``; the earlier
    # ``fill_constant_batch_size_like`` path and its early return are gone,
    # so this code is now reachable.
    batch = fluid.layers.shape(z)[0]
    ones = fluid.layers.fill_constant(
        shape=[batch, a.shape[1], z.shape[2], z.shape[3]],
        dtype="float32",
        value=1.0)
    return fluid.layers.concat(
        [z, fluid.layers.elementwise_mul(
            ones, a, axis=0)], axis=1)
def Genc(self, input, dim=64, n_layers=5, name='G_enc_', is_test=False):
z = input
......
......@@ -42,7 +42,9 @@ def norm_layer(input,
if norm_type == 'batch_norm':
if affine == True:
param_attr = fluid.ParamAttr(
name=name + '_w', initializer=fluid.initializer.Normal(loc=1.0, scale=0.02))
name=name + '_w',
initializer=fluid.initializer.Normal(
loc=1.0, scale=0.02))
bias_attr = fluid.ParamAttr(
name=name + '_b',
initializer=fluid.initializer.Constant(value=0.0))
......@@ -366,8 +368,11 @@ def linear(input,
def conv_cond_concat(x, y):
    """Concatenate the conditioning tensor ``y`` to ``x`` along axis 1.

    A tensor of ones with shape
    ``[batch, y.shape[1], x.shape[2], x.shape[3]]`` is created, multiplied by
    ``y`` and concatenated to ``x`` along axis 1.

    Args:
        x: Feature map; indexed up to ``x.shape[3]``, so it must have at
            least four dimensions.
        y: Conditioning tensor; ``y.shape[1]`` supplies the number of
            channels appended to ``x``.
            NOTE(review): presumably shaped so that ``ones * y`` broadcasts
            over the spatial dimensions -- confirm against callers.

    Returns:
        The concatenation of ``x`` and ``ones * y`` along axis 1.
    """
    # Run-time batch size of ``x``.
    batch = fluid.layers.shape(x)[0]
    # The leading dimension must be ``batch``: the previous code passed
    # ``ones`` here, i.e. the very tensor being created, leaving ``batch``
    # unused and the shape invalid.
    ones = fluid.layers.fill_constant(
        shape=[batch, y.shape[1], x.shape[2], x.shape[3]],
        dtype="float32",
        value=1.0)
    out = fluid.layers.concat([x, ones * y], 1)
    return out
......
......@@ -44,11 +44,9 @@ class GTrainer():
self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
#lsgan
elif cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_fake,
shape=self.pred_fake.shape,
value=1.0,
dtype='float32')
fake_shape = fluid.layers.shape(self.pred_fake)
ones = fluid.layers.fill_constant(
shape=fake_shape, value=1.0, dtype='float32')
self.g_loss_fake = fluid.layers.mean(
fluid.layers.square(
fluid.layers.elementwise_sub(
......@@ -106,11 +104,9 @@ class DTrainer():
self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
#lsgan
elif cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_real,
shape=self.pred_real.shape,
value=1.0,
dtype='float32')
real_shape = fluid.layers.shape(self.pred_real)
ones = fluid.layers.fill_constant(
shape=real_shape, value=1.0, dtype='float32')
self.d_loss_real = fluid.layers.mean(
fluid.layers.square(
fluid.layers.elementwise_sub(
......@@ -145,31 +141,31 @@ class DTrainer():
def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
def _interpolate(a, b=None):
a_shape = fluid.layers.shape(a)
if b is None:
if cfg.enable_ce:
beta = fluid.layers.uniform_random_batch_size_like(
input=a, shape=a.shape, min=0.0, max=1.0, seed=1)
beta = fluid.layers.uniform_random(
shape=a_shape, min=0.0, max=1.0, seed=1)
else:
beta = fluid.layers.uniform_random_batch_size_like(
input=a, shape=a.shape, min=0.0, max=1.0)
beta = fluid.layers.uniform_random(
shape=a_shape, min=0.0, max=1.0)
mean = fluid.layers.reduce_mean(
a, dim=list(range(len(a.shape))), keep_dim=True)
a, dim=list(range(len(a.shape))))
input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
var = fluid.layers.reduce_mean(
fluid.layers.square(input_sub_mean),
dim=list(range(len(a.shape))),
keep_dim=True)
dim=list(range(len(a.shape))))
b = beta * fluid.layers.sqrt(var) * 0.5 + a
shape = [a.shape[0]]
if cfg.enable_ce:
alpha = fluid.layers.uniform_random_batch_size_like(
input=a, shape=shape, min=0.0, max=1.0, seed=1)
alpha = fluid.layers.uniform_random(
shape=a_shape[0], min=0.0, max=1.0, seed=1)
else:
alpha = fluid.layers.uniform_random_batch_size_like(
input=a, shape=shape, min=0.0, max=1.0)
alpha = fluid.layers.uniform_random(
shape=a_shape[0], min=0.0, max=1.0)
inner = fluid.layers.elementwise_mul((b-a), alpha, axis=0) + a
inner = fluid.layers.elementwise_mul((b - a), alpha, axis=0) + a
return inner
x = _interpolate(real, fake)
......@@ -336,7 +332,7 @@ class AttGAN(object):
if self.cfg.enable_ce:
gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90
t_time = 0
for epoch_id in range(self.cfg.epoch):
......@@ -379,7 +375,7 @@ class AttGAN(object):
sys.stdout.flush()
batch_id += 1
if self.cfg.enable_ce and batch_id == 100:
break
break
if self.cfg.run_test:
image_name = fluid.data(
......@@ -402,17 +398,23 @@ class AttGAN(object):
test_loader)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, gen_trainer,
"net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer,
"net_D")
utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")
# used for continuous evaluation
if self.cfg.enable_ce:
device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1
print("kpis\tattgan_g_loss_fake_card{}\t{}".format(device_num, g_loss_fake[0]))
print("kpis\tattgan_g_loss_rec_card{}\t{}".format(device_num, g_loss_rec[0]))
print("kpis\tattgan_g_loss_cls_card{}\t{}".format(device_num, g_loss_cls[0]))
print("kpis\tattgan_d_loss_real_card{}\t{}".format(device_num, d_loss_real[0]))
print("kpis\tattgan_d_loss_fake_card{}\t{}".format(device_num,d_loss_fake[0]))
print("kpis\tattgan_d_loss_gp_card{}\t{}".format(device_num,d_loss_gp[0]))
print("kpis\tattgan_Batch_time_cost_card{}\t{}".format(device_num,batch_time))
device_num = fluid.core.get_cuda_device_count(
) if self.cfg.use_gpu else 1
print("kpis\tattgan_g_loss_fake_card{}\t{}".format(
device_num, g_loss_fake[0]))
print("kpis\tattgan_g_loss_rec_card{}\t{}".format(
device_num, g_loss_rec[0]))
print("kpis\tattgan_g_loss_cls_card{}\t{}".format(
device_num, g_loss_cls[0]))
print("kpis\tattgan_d_loss_real_card{}\t{}".format(
device_num, d_loss_real[0]))
print("kpis\tattgan_d_loss_fake_card{}\t{}".format(
device_num, d_loss_fake[0]))
print("kpis\tattgan_d_loss_gp_card{}\t{}".format(device_num,
d_loss_gp[0]))
print("kpis\tattgan_Batch_time_cost_card{}\t{}".format(
device_num, batch_time))
......@@ -37,8 +37,9 @@ class GTrainer():
self.fake = model.network_G(input, conditions, name="G")
self.infer_program = self.program.clone(for_test=True)
d_fake = model.network_D(self.fake, conditions, name="D")
fake_labels = fluid.layers.fill_constant_batch_size_like(
input=input, dtype='float32', shape=[-1, 1], value=1.0)
batch = fluid.layers.shape(input)[0]
fake_labels = fluid.layers.fill_constant(
dtype='float32', shape=[batch, 1], value=1.0)
self.g_loss = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_fake, label=fake_labels))
......
......@@ -38,8 +38,9 @@ class GTrainer():
self.fake = model.network_G(input, name='G')
self.infer_program = self.program.clone(for_test=True)
d_fake = model.network_D(self.fake, name="D")
fake_labels = fluid.layers.fill_constant_batch_size_like(
input, dtype='float32', shape=[-1, 1], value=1.0)
batch = fluid.layers.shape(input)[0]
fake_labels = fluid.layers.fill_constant(
dtype='float32', shape=[batch, 1], value=1.0)
self.g_loss = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_fake, label=fake_labels))
......
......@@ -33,10 +33,10 @@ class GTrainer():
self.infer_program = self.program.clone()
AB = fluid.layers.concat([input_A, self.fake_B], 1)
self.pred = model.network_D(AB, "discriminator", cfg)
batch = fluid.layers.shape(self.pred)[0]
if cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred,
shape=self.pred.shape,
ones = fluid.layers.fill_constant(
shape=[batch] + list(self.pred.shape[1:]),
value=1,
dtype='float32')
self.g_loss_gan = fluid.layers.reduce_mean(
......@@ -49,9 +49,8 @@ class GTrainer():
self.pred,
[-1, pred_shape[1] * pred_shape[2] * pred_shape[3]],
inplace=True)
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred,
shape=self.pred.shape,
ones = fluid.layers.fill_constant(
shape=[batch] + list(self.pred.shape[1:]),
value=1,
dtype='float32')
self.g_loss_gan = fluid.layers.mean(
......@@ -106,10 +105,10 @@ class DTrainer():
self.real_AB, "discriminator", cfg=cfg)
self.pred_fake = model.network_D(
self.fake_AB, "discriminator", cfg=cfg)
batch = fluid.layers.shape(input_A)[0]
if cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_real,
shape=self.pred_real.shape,
ones = fluid.layers.fill_constant(
shape=[batch] + list(self.pred_real.shape[1:]),
value=1,
dtype='float32')
self.d_loss_real = fluid.layers.reduce_mean(
......@@ -128,14 +127,12 @@ class DTrainer():
self.pred_fake,
[-1, pred_shape[1] * pred_shape[2] * pred_shape[3]],
inplace=True)
zeros = fluid.layers.fill_constant_batch_size_like(
input=self.pred_fake,
shape=self.pred_fake.shape,
zeros = fluid.layers.fill_constant(
shape=[batch] + list(self.pred_fake.shape[1:]),
value=0,
dtype='float32')
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_real,
shape=self.pred_real.shape,
ones = fluid.layers.fill_constant(
shape=[batch] + list(self.pred_real.shape[1:]),
value=1,
dtype='float32')
self.d_loss_real = fluid.layers.mean(
......@@ -283,7 +280,8 @@ class Pix2pix(object):
devices_num = utility.get_device_num(self.cfg)
fake_per_device = int(len(fake_B_tmp) / devices_num)
for dev in range(devices_num):
tensor[dev]['input_fake'] = fake_B_tmp[dev * fake_per_device : (dev+1) * fake_per_device]
tensor[dev]['input_fake'] = fake_B_tmp[
dev * fake_per_device:(dev + 1) * fake_per_device]
# optimize the discriminator network
d_loss_real, d_loss_fake = exe.run(dis_trainer_program,
......@@ -338,10 +336,8 @@ class Pix2pix(object):
A_id2name=self.id2name)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, gen_trainer,
"net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer,
"net_D")
utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")
if self.cfg.enable_ce:
device_num = fluid.core.get_cuda_device_count(
) if self.cfg.use_gpu else 1
......
......@@ -144,11 +144,9 @@ class DTrainer():
#####gan loss
self.gan_loss_fake = 0
for pred_i in self.pred_fake:
zeros = fluid.layers.fill_constant_batch_size_like(
input=pred_i[-1],
shape=pred_i[-1].shape,
value=0,
dtype='float32')
pred_shape = fluid.layers.shape(pred_i[-1])
zeros = fluid.layers.fill_constant(
shape=pred_shape, value=0, dtype='float32')
if isinstance(pred_i, list):
pred_i = pred_i[-1]
minval = fluid.layers.elementwise_min(-1 * pred_i - 1, zeros)
......@@ -158,11 +156,9 @@ class DTrainer():
self.gan_loss_real = 0
for pred_i in self.pred_real:
zeros = fluid.layers.fill_constant_batch_size_like(
input=pred_i[-1],
shape=pred_i[-1].shape,
value=0,
dtype='float32')
pred_shape = fluid.layers.shape(pred_i[-1])
zeros = fluid.layers.fill_constant(
shape=pred_shape, value=0, dtype='float32')
if isinstance(pred_i, list):
pred_i = pred_i[-1]
minval = fluid.layers.elementwise_min(pred_i - 1, zeros)
......@@ -298,7 +294,7 @@ class SPADE(object):
name='input_fake', shape=data_shape, dtype='float32')
# used for continuous evaluation
if self.cfg.enable_ce:
fluid.default_startup_program().random_seed = 90
fluid.default_startup_program().random_seed = 90
gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg,
self.batch_num)
......@@ -348,7 +344,7 @@ class SPADE(object):
if self.cfg.enable_ce:
gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90
t_time = 0
for epoch_id in range(self.cfg.epoch):
......@@ -422,16 +418,21 @@ class SPADE(object):
A_id2name=self.id2name)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, gen_trainer,
"net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer,
"net_D")
utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")
# used for continuous evaluation
if self.cfg.enable_ce:
device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1
print("kpis\tspade_g_loss_gan_card{}\t{}".format(device_num, g_loss_gan[0]))
print("kpis\tspade_g_loss_vgg_card{}\t{}".format(device_num,g_loss_vgg[0]))
print("kpis\tspade_g_loss_feat_card{}\t{}".format(device_num,g_loss_feat[0]))
print("kpis\tspade_d_loss_real_card{}\t{}".format(device_num,d_loss_real[0]))
print("kpis\tspade_d_loss_fake_card{}\t{}".format(device_num,d_loss_fake[0]))
print("kpis\tspade_Batch_time_cost_card{}\t{}".format(device_num,batch_time))
device_num = fluid.core.get_cuda_device_count(
) if self.cfg.use_gpu else 1
print("kpis\tspade_g_loss_gan_card{}\t{}".format(device_num,
g_loss_gan[0]))
print("kpis\tspade_g_loss_vgg_card{}\t{}".format(device_num,
g_loss_vgg[0]))
print("kpis\tspade_g_loss_feat_card{}\t{}".format(
device_num, g_loss_feat[0]))
print("kpis\tspade_d_loss_real_card{}\t{}".format(
device_num, d_loss_real[0]))
print("kpis\tspade_d_loss_fake_card{}\t{}".format(
device_num, d_loss_fake[0]))
print("kpis\tspade_Batch_time_cost_card{}\t{}".format(
device_num, batch_time))
......@@ -45,11 +45,9 @@ class GTrainer():
self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
#lsgan
elif cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_fake,
shape=self.pred_fake.shape,
value=1.0,
dtype='float32')
fake_shape = fluid.layers.shape(self.pred_fake)
ones = fluid.layers.fill_constant(
shape=fake_shape, value=1.0, dtype='float32')
self.g_loss_fake = fluid.layers.mean(
fluid.layers.square(
fluid.layers.elementwise_sub(
......@@ -108,11 +106,9 @@ class DTrainer():
self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
#lsgan
elif cfg.gan_mode == "lsgan":
ones = fluid.layers.fill_constant_batch_size_like(
input=self.pred_real,
shape=self.pred_real.shape,
value=1.0,
dtype='float32')
real_shape = fluid.layers.shape(self.pred_real)
ones = fluid.layers.fill_constant(
shape=real_shape, value=1.0, dtype='float32')
self.d_loss_real = fluid.layers.mean(
fluid.layers.square(
fluid.layers.elementwise_sub(
......@@ -149,31 +145,30 @@ class DTrainer():
def gradient_penalty(self, f, real, fake=None, cfg=None, name=None):
def _interpolate(a, b=None):
a_shape = fluid.layers.shape(a)
if b is None:
if cfg.enable_ce:
beta = fluid.layers.uniform_random_batch_size_like(
input=a, shape=a.shape, min=0.0, max=1.0, seed=1)
beta = fluid.layers.uniform_random(
shape=a_shape, min=0.0, max=1.0, seed=1)
else:
beta = fluid.layers.uniform_random_batch_size_like(
input=a, shape=a.shape, min=0.0, max=1.0)
beta = fluid.layers.uniform_random(
shape=a_shape, min=0.0, max=1.0)
mean = fluid.layers.reduce_mean(
a, dim=list(range(len(a.shape))), keep_dim=True)
a, dim=list(range(len(a.shape))))
input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
var = fluid.layers.reduce_mean(
fluid.layers.square(input_sub_mean),
dim=list(range(len(a.shape))),
keep_dim=True)
dim=list(range(len(a.shape))))
b = beta * fluid.layers.sqrt(var) * 0.5 + a
shape = [a.shape[0]]
if cfg.enable_ce:
alpha = fluid.layers.uniform_random_batch_size_like(
input=a, shape=shape, min=0.0, max=1.0, seed=1)
else:
alpha = fluid.layers.uniform_random_batch_size_like(
input=a, shape=shape, min=0.0, max=1.0)
alpha = fluid.layers.uniform_random(
shape=a_shape[0], min=0.0, max=1.0, seed=1)
else:
alpha = fluid.layers.uniform_random(
shape=a_shape[0], min=0.0, max=1.0)
inner = fluid.layers.elementwise_mul((b-a), alpha, axis=0) + a
inner = fluid.layers.elementwise_mul((b - a), alpha, axis=0) + a
return inner
x = _interpolate(real, fake)
......@@ -221,7 +216,10 @@ class STGAN(object):
default=1024,
help="the base fc dim in discriminator")
parser.add_argument(
'--use_gru', type=ast.literal_eval, default=True, help="whether to use GRU")
'--use_gru',
type=ast.literal_eval,
default=True,
help="whether to use GRU")
parser.add_argument(
'--lambda_cls',
type=float,
......@@ -345,7 +343,7 @@ class STGAN(object):
if self.cfg.enable_ce:
gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90
t_time = 0
total_train_batch = 0 # used for benchmark
......@@ -353,7 +351,7 @@ class STGAN(object):
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for data in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time()
# optimize the discriminator network
......@@ -389,7 +387,7 @@ class STGAN(object):
sys.stdout.flush()
batch_id += 1
if self.cfg.enable_ce and batch_id == 100:
break
break
total_train_batch += 1 # used for benchmark
# profiler tools
......@@ -418,19 +416,27 @@ class STGAN(object):
test_loader)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, gen_trainer,
"net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer,
"net_D")
utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")
# used for continuous evaluation
if self.cfg.enable_ce:
device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1
print("kpis\tstgan_g_loss_fake_card{}\t{}".format(device_num, g_loss_fake[0]))
print("kpis\tstgan_g_loss_rec_card{}\t{}".format(device_num, g_loss_rec[0]))
print("kpis\tstgan_g_loss_cls_card{}\t{}".format(device_num, g_loss_cls[0]))
print("kpis\tstgan_d_loss_card{}\t{}".format(device_num, d_loss[0]))
print("kpis\tstgan_d_loss_real_card{}\t{}".format(device_num, d_loss_real[0]))
print("kpis\tstgan_d_loss_fake_card{}\t{}".format(device_num,d_loss_fake[0]))
print("kpis\tstgan_d_loss_cls_card{}\t{}".format(device_num, d_loss_cls[0]))
print("kpis\tstgan_d_loss_gp_card{}\t{}".format(device_num,d_loss_gp[0]))
print("kpis\tstgan_Batch_time_cost_card{}\t{}".format(device_num,batch_time))
device_num = fluid.core.get_cuda_device_count(
) if self.cfg.use_gpu else 1
print("kpis\tstgan_g_loss_fake_card{}\t{}".format(
device_num, g_loss_fake[0]))
print("kpis\tstgan_g_loss_rec_card{}\t{}".format(device_num,
g_loss_rec[0]))
print("kpis\tstgan_g_loss_cls_card{}\t{}".format(device_num,
g_loss_cls[0]))
print("kpis\tstgan_d_loss_card{}\t{}".format(device_num, d_loss[
0]))
print("kpis\tstgan_d_loss_real_card{}\t{}".format(
device_num, d_loss_real[0]))
print("kpis\tstgan_d_loss_fake_card{}\t{}".format(
device_num, d_loss_fake[0]))
print("kpis\tstgan_d_loss_cls_card{}\t{}".format(device_num,
d_loss_cls[0]))
print("kpis\tstgan_d_loss_gp_card{}\t{}".format(device_num,
d_loss_gp[0]))
print("kpis\tstgan_Batch_time_cost_card{}\t{}".format(
device_num, batch_time))
......@@ -149,15 +149,17 @@ class DTrainer():
def gradient_penalty(self, f, real, fake, cfg=None, name=None):
def _interpolate(a, b):
shape = [a.shape[0]]
a_shape = fluid.layers.shape(a)
if cfg.enable_ce:
alpha = fluid.layers.uniform_random_batch_size_like(
input=a, shape=shape, min=0.0, max=1.0, seed=1)
else:
alpha = fluid.layers.uniform_random_batch_size_like(
input=a, shape=shape, min=0.0, max=1.0)
alpha = fluid.layers.uniform_random(
shape=[a_shape[0]], min=0.0, max=1.0, seed=1)
else:
alpha = fluid.layers.uniform_random(
shape=[a_shape[0]], min=0.0, max=1.0)
inner = fluid.layers.elementwise_mul(b, (1.0-alpha), axis=0) + fluid.layers.elementwise_mul(a, alpha, axis=0)
inner = fluid.layers.elementwise_mul(
b, (1.0 - alpha), axis=0) + fluid.layers.elementwise_mul(
a, alpha, axis=0)
return inner
x = _interpolate(real, fake)
......@@ -316,7 +318,7 @@ class StarGAN(object):
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for data in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time()
d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
......@@ -355,7 +357,7 @@ class StarGAN(object):
batch_id += 1
# used for ce
if self.cfg.enable_ce and batch_id == 100:
break
break
total_train_batch += 1 # used for benchmark
# profiler tools
......@@ -380,22 +382,28 @@ class StarGAN(object):
if self.cfg.use_gpu else fluid.cpu_places())
test_program = gen_trainer.infer_program
utility.save_test_image(epoch_id, self.cfg, exe, place,
test_program, gen_trainer,
test_loader)
test_program, gen_trainer, test_loader)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, gen_trainer,
"net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer,
"net_D")
utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
utility.checkpoints(epoch_id, self.cfg, dis_trainer, "net_D")
# used for continuous evaluation
if self.cfg.enable_ce:
device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1
print("kpis\tstargan_g_loss_fake_card{}\t{}".format(device_num, g_loss_fake[0]))
print("kpis\tstargan_g_loss_rec_card{}\t{}".format(device_num, g_loss_rec[0]))
print("kpis\tstargan_g_loss_cls_card{}\t{}".format(device_num, g_loss_cls[0]))
print("kpis\tstargan_d_loss_real_card{}\t{}".format(device_num, d_loss_real[0]))
print("kpis\tstargan_d_loss_fake_card{}\t{}".format(device_num,d_loss_fake[0]))
print("kpis\tstargan_d_loss_cls_card{}\t{}".format(device_num, d_loss_cls[0]))
print("kpis\tstargan_d_loss_gp_card{}\t{}".format(device_num,d_loss_gp[0]))
print("kpis\tstargan_Batch_time_cost_card{}\t{}".format(device_num,batch_time))
device_num = fluid.core.get_cuda_device_count(
) if self.cfg.use_gpu else 1
print("kpis\tstargan_g_loss_fake_card{}\t{}".format(
device_num, g_loss_fake[0]))
print("kpis\tstargan_g_loss_rec_card{}\t{}".format(
device_num, g_loss_rec[0]))
print("kpis\tstargan_g_loss_cls_card{}\t{}".format(
device_num, g_loss_cls[0]))
print("kpis\tstargan_d_loss_real_card{}\t{}".format(
device_num, d_loss_real[0]))
print("kpis\tstargan_d_loss_fake_card{}\t{}".format(
device_num, d_loss_fake[0]))
print("kpis\tstargan_d_loss_cls_card{}\t{}".format(
device_num, d_loss_cls[0]))
print("kpis\tstargan_d_loss_gp_card{}\t{}".format(device_num,
d_loss_gp[0]))
print("kpis\tstargan_Batch_time_cost_card{}\t{}".format(
device_num, batch_time))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册