From b587334f24be09e7c1ea24139a930fc4d5ba77c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=BE=E6=99=93?= Date: Mon, 6 Jan 2020 14:04:09 +0800 Subject: [PATCH] add spade models (#4146) --- PaddleCV/PaddleGAN/trainer/SPADE.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/PaddleCV/PaddleGAN/trainer/SPADE.py b/PaddleCV/PaddleGAN/trainer/SPADE.py index b11c9b6c..26466e30 100644 --- a/PaddleCV/PaddleGAN/trainer/SPADE.py +++ b/PaddleCV/PaddleGAN/trainer/SPADE.py @@ -268,7 +268,11 @@ class SPADE(object): type=bool, default=False, help="Whether to use instance label.") - + parser.add_argument( + '--enable_ce', + type=bool, + default=False, + help="If set True, enable continuous evaluation job.") return parser def __init__(self, @@ -298,6 +302,9 @@ class SPADE(object): name='input_ins', shape=edge_shape, dtype='float32') input_fake = fluid.data( name='input_fake', shape=data_shape, dtype='float32') + # used for continuous evaluation + if self.cfg.enable_ce: + fluid.default_startup_program().random_seed = 90 gen_trainer = GTrainer(input_A, input_B, input_C, self.cfg, self.batch_num) @@ -343,7 +350,11 @@ class SPADE(object): dis_trainer.program).with_data_parallel( loss_name=dis_trainer.d_loss.name, build_strategy=build_strategy) - + # used for continuous evaluation + if self.cfg.enable_ce: + gen_trainer_program.random_seed = 90 + dis_trainer_program.random_seed = 90 + t_time = 0 for epoch_id in range(self.cfg.epoch): @@ -390,7 +401,6 @@ class SPADE(object): 0], batch_time)) sys.stdout.flush() - batch_id += 1 if self.cfg.run_test: test_program = gen_trainer.infer_program @@ -422,3 +432,12 @@ class SPADE(object): "net_G") utility.checkpoints(epoch_id, self.cfg, exe, dis_trainer, "net_D") + # used for continuous evaluation + if self.cfg.enable_ce: + device_num = fluid.core.get_cuda_device_count() if self.cfg.use_gpu else 1 + print("kpis\tspade_g_loss_gan_card{}\t{}".format(device_num, g_loss_gan[0])) + print("kpis\tspade_g_loss_vgg_card{}\t{}".format(device_num,g_loss_vgg[0])) + print("kpis\tspade_g_loss_feat_card{}\t{}".format(device_num,g_loss_feat[0])) + print("kpis\tspade_d_loss_real_card{}\t{}".format(device_num,d_loss_real[0])) + print("kpis\tspade_d_loss_fake_card{}\t{}".format(device_num,d_loss_fake[0])) + print("kpis\tspade_Batch_time_cost_card{}\t{}".format(device_num,batch_time)) -- GitLab