Unverified · Commit 93c4daa4 · authored by Yiqun Liu · committed by GitHub

Calculate the average time for gan models when benchmarking. (#4873)

Parent b9b8c888
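In short: each GAN trainer previously printed the cost of one batch (`batch_time`); this change adds a `TimeAverager` in the new `util/timer.py` and reports the average reader cost (time spent waiting on the data loader) and average batch cost over each `print_freq` window instead. A minimal, self-contained sketch of the pattern the trainers below follow (the fake loader, the 10 ms/20 ms sleeps, and the `print_freq` value are illustrative stand-ins, not values from the patch):

import time

from util import timer  # util/timer.py is the module added by this commit


def fake_loader(n=20):
    # Stand-in for a paddle DataLoader; the sleep simulates reader latency.
    for i in range(n):
        time.sleep(0.01)
        yield i


print_freq = 10  # stand-in for cfg.print_freq
reader_cost_averager = timer.TimeAverager()
batch_cost_averager = timer.TimeAverager()

batch_start = time.time()
for batch_id, batch in enumerate(fake_loader()):
    # Reader cost: time from the end of the previous step until data arrived.
    reader_cost_averager.record(time.time() - batch_start)
    time.sleep(0.02)  # stand-in for the exe.run(...) training step
    # Batch cost: the whole iteration, reader wait plus training.
    batch_cost_averager.record(time.time() - batch_start)
    if (batch_id + 1) % print_freq == 0:
        print("avg reader_cost: {:.4f} s, avg batch_cost: {:.4f} s".format(
            reader_cost_averager.get_average(),
            batch_cost_averager.get_average()))
        reader_cost_averager.reset()
        batch_cost_averager.reset()
    batch_start = time.time()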
@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from network.CycleGAN_network import CycleGAN_model
from util import utility
+ from util import timer
import paddle.fluid as fluid
from paddle.fluid import profiler
import paddle
@@ -291,15 +292,17 @@ class CycleGAN(object):
loss_name=d_B_trainer.d_loss_B.name,
build_strategy=build_strategy)
- t_time = 0
total_train_batch = 0  # NOTE :used for benchmark
+ reader_cost_averager = timer.TimeAverager()
+ batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch):
batch_id = 0
+ batch_start = time.time()
for data_A, data_B in zip(A_loader(), B_loader()):
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter:  # used for benchmark
return
- s_time = time.time()
+ reader_cost_averager.record(time.time() - batch_start)
tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
## optimize the g_A network
g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\
@@ -335,26 +338,31 @@ class CycleGAN(object):
feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0]
- batch_time = time.time() - s_time
- t_time += batch_time
+ batch_cost_averager.record(time.time() - batch_start)
if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
d_A_loss: {}; g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; \n\
d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \n\
- Batch_time_cost: {}".format(
- epoch_id, batch_id, d_A_loss[0], g_A_loss[0],
- g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0], g_B_loss[
- 0], g_B_cyc_loss[0], g_B_idt_loss[0], batch_time))
+ reader_cost: {}, Batch_time_cost: {}"
+ .format(epoch_id, batch_id, d_A_loss[0], g_A_loss[
+ 0], g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0],
+ g_B_loss[0], g_B_cyc_loss[0], g_B_idt_loss[0],
+ reader_cost_averager.get_average(),
+ batch_cost_averager.get_average()))
+ reader_cost_averager.reset()
+ batch_cost_averager.reset()
sys.stdout.flush()
batch_id += 1
- #NOTE: used for benchmark
total_train_batch += 1  # used for benchmark
+ batch_start = time.time()
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
# used for continuous evaluation
if self.cfg.enable_ce and batch_id == 10:
break
@@ -398,12 +406,9 @@ class CycleGAN(object):
B_id2name=self.B_id2name)
if self.cfg.save_checkpoints:
- utility.checkpoints(epoch_id, self.cfg, gen_trainer,
- "net_G")
- utility.checkpoints(epoch_id, self.cfg, d_A_trainer,
- "net_DA")
- utility.checkpoints(epoch_id, self.cfg, d_B_trainer,
- "net_DB")
+ utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
+ utility.checkpoints(epoch_id, self.cfg, d_A_trainer, "net_DA")
+ utility.checkpoints(epoch_id, self.cfg, d_B_trainer, "net_DB")
# used for continuous evaluation
if self.cfg.enable_ce:
......
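Aside on the `profiler` guard that recurs in each trainer's loop: with `--profile` enabled, statistics from the first `cfg.print_freq` warm-up batches of epoch 0 are discarded via `profiler.reset_profiler()`, and the run returns five batches later so the report covers exactly five steady-state batches. A sketch of that control flow (the `profile_guard` helper is hypothetical; the trainers inline this logic):

from paddle.fluid import profiler


def profile_guard(cfg, epoch_id, batch_id):
    # Mirrors the in-loop profiler logic shown in the hunks above.
    if cfg.profile and epoch_id == 0 and batch_id == cfg.print_freq:
        profiler.reset_profiler()  # drop statistics from warm-up batches
    elif cfg.profile and epoch_id == 0 and batch_id == cfg.print_freq + 5:
        return True  # caller should stop: only 5 batches are profiled
    return False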
@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from network.Pix2pix_network import Pix2pix_model
from util import utility
+ from util import timer
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
@@ -257,16 +258,16 @@ class Pix2pix(object):
loss_name=dis_trainer.d_loss.name,
build_strategy=build_strategy)
- t_time = 0
total_train_batch = 0  # used for benchmark
+ reader_cost_averager = timer.TimeAverager()
+ batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch):
batch_id = 0
+ batch_start = time.time()
for tensor in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
- s_time = time.time()
+ reader_cost_averager.record(time.time() - batch_start)
# optimize the generator network
g_loss_gan, g_loss_l1, fake_B_tmp = exe.run(
@@ -291,19 +292,24 @@ class Pix2pix(object):
],
feed=tensor)
- batch_time = time.time() - s_time
- t_time += batch_time
+ batch_cost_averager.record(time.time() - batch_start)
if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
g_loss_gan: {}; g_loss_l1: {}; \n\
d_loss_real: {}; d_loss_fake: {}; \n\
- Batch_time_cost: {}"
+ reader_cost: {}, Batch_time_cost: {}"
.format(epoch_id, batch_id, g_loss_gan[0], g_loss_l1[
- 0], d_loss_real[0], d_loss_fake[0], batch_time))
+ 0], d_loss_real[0], d_loss_fake[0],
+ reader_cost_averager.get_average(),
+ batch_cost_averager.get_average()))
+ reader_cost_averager.reset()
+ batch_cost_averager.reset()
sys.stdout.flush()
batch_id += 1
total_train_batch += 1 # used for benchmark
+ batch_start = time.time()
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
......
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.STGAN_network import STGAN_model
from util import utility
+ from util import timer
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
@@ -344,16 +346,17 @@ class STGAN(object):
gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90
- t_time = 0
total_train_batch = 0  # used for benchmark
+ reader_cost_averager = timer.TimeAverager()
+ batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch):
batch_id = 0
+ batch_start = time.time()
for data in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
- s_time = time.time()
+ reader_cost_averager.record(time.time() - batch_start)
# optimize the discriminator network
fetches = [
dis_trainer.d_loss.name,
@@ -376,20 +379,27 @@ class STGAN(object):
g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
.format(epoch_id, batch_id, g_loss_fake[0],
g_loss_rec[0], g_loss_cls[0]))
- batch_time = time.time() - s_time
- t_time += batch_time
+ batch_cost_averager.record(time.time() - batch_start)
if (batch_id + 1) % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
- Batch_time_cost: {}".format(epoch_id, batch_id, d_loss[
- 0], d_loss_real[0], d_loss_fake[0], d_loss_cls[0],
- d_loss_gp[0], batch_time))
+ reader_cost: {}, Batch_time_cost: {}"
+ .format(epoch_id, batch_id, d_loss[0], d_loss_real[0],
+ d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
+ reader_cost_averager.get_average(),
+ batch_cost_averager.get_average()))
+ reader_cost_averager.reset()
+ batch_cost_averager.reset()
sys.stdout.flush()
batch_id += 1
+ total_train_batch += 1  # used for benchmark
+ batch_start = time.time()
if self.cfg.enable_ce and batch_id == 100:
break
- total_train_batch += 1  # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
......
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from network.StarGAN_network import StarGAN_model
from util import utility
+ from util import timer
import paddle.fluid as fluid
from paddle.fluid import profiler
import sys
@@ -313,14 +315,17 @@ class StarGAN(object):
gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90
- t_time = 0
total_train_batch = 0  # used for benchmark
+ reader_cost_averager = timer.TimeAverager()
+ batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch):
batch_id = 0
+ batch_start = time.time()
for data in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
- s_time = time.time()
+ reader_cost_averager.record(time.time() - batch_start)
d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
dis_trainer_program,
fetch_list=[
@@ -344,22 +349,27 @@ class StarGAN(object):
.format(epoch_id, batch_id, g_loss_fake[0],
g_loss_rec[0], g_loss_cls[0]))
- batch_time = time.time() - s_time
- t_time += batch_time
+ batch_cost_averager.record(time.time() - batch_start)
if (batch_id + 1) % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\
d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
- Batch_time_cost: {}".format(
- epoch_id, batch_id, d_loss_real[0], d_loss_fake[
- 0], d_loss_cls[0], d_loss_gp[0], batch_time))
+ reader_cost: {}, Batch_time_cost: {}"
+ .format(epoch_id, batch_id, d_loss_real[0],
+ d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
+ reader_cost_averager.get_average(),
+ batch_cost_averager.get_average()))
+ reader_cost_averager.reset()
+ batch_cost_averager.reset()
sys.stdout.flush()
batch_id += 1
+ total_train_batch += 1  # used for benchmark
+ batch_start = time.time()
# used for ce
if self.cfg.enable_ce and batch_id == 100:
break
- total_train_batch += 1  # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time


class TimeAverager(object):
    """Accumulates time costs and reports their running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated count and total so averaging restarts."""
        self._cnt = 0
        self._total_time = 0

    def record(self, usetime):
        """Add one measured time cost (in seconds) to the accumulator."""
        self._cnt += 1
        self._total_time += usetime

    def get_average(self):
        """Return the mean of all recorded costs, or 0 if none recorded."""
        if self._cnt == 0:
            return 0
        return self._total_time / self._cnt
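For reference, `TimeAverager` behaves as a simple resettable running mean; `get_average()` returns 0 before anything has been recorded, and `reset()` starts a fresh averaging window (the literal costs below are illustrative):

averager = TimeAverager()
print(averager.get_average())  # 0: nothing recorded yet

for cost in (0.1, 0.2, 0.3):
    averager.record(cost)
print(averager.get_average())  # ~0.2, the mean of the three records

averager.reset()               # discard history for the next window
print(averager.get_average())  # 0 again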