提交 b587334f 编写于 作者: u010070587's avatar u010070587 提交者: ceci3

add spade models (#4146)

上级 c1e7c0f8
......@@ -268,7 +268,11 @@ class SPADE(object):
type=bool,
default=False,
help="Whether to use instance label.")
parser.add_argument(
'--enable_ce',
type=bool,
default=False,
help="If set True, enable continuous evaluation job.")
return parser
def __init__(self,
......@@ -298,6 +302,9 @@ class SPADE(object):
name='input_ins', shape=edge_shape, dtype='float32')
input_fake = fluid.data(
name='input_fake', shape=data_shape, dtype='float32')
# used for continuous evaluation
if self.cfg.enable_ce:
fluid.default_startup_program().random_seed = 90
gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg,
self.batch_num)
......@@ -343,7 +350,11 @@ class SPADE(object):
dis_trainer.program).with_data_parallel(
loss_name=dis_trainer.d_loss.name,
build_strategy=build_strategy)
# used for continuous evaluation
if self.cfg.enable_ce:
gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90
t_time = 0
for epoch_id in range(self.cfg.epoch):
......@@ -390,7 +401,6 @@ class SPADE(object):
0], batch_time))
sys.stdout.flush()
batch_id += 1
if self.cfg.run_test:
test_program = gen_trainer.infer_program
......@@ -422,3 +432,12 @@ class SPADE(object):
"net_G")
utility.checkpoints(epoch_id, self.cfg, exe, dis_trainer,
"net_D")
# used for continuous evaluation
if self.cfg.enable_ce:
device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1
print("kpis\tspade_g_loss_gan_card{}\t{}".format(device_num, g_loss_gan[0]))
print("kpis\tspade_g_loss_vgg_card{}\t{}".format(device_num,g_loss_vgg[0]))
print("kpis\tspade_g_loss_feat_card{}\t{}".format(device_num,g_loss_feat[0]))
print("kpis\tspade_d_loss_real_card{}\t{}".format(device_num,d_loss_real[0]))
print("kpis\tspade_d_loss_fake_card{}\t{}".format(device_num,d_loss_fake[0]))
print("kpis\tspade_Batch_time_cost_card{}\t{}".format(device_num,batch_time))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册