未验证 提交 93c4daa4 编写于 作者: Y Yiqun Liu 提交者: GitHub

Calculate the average time for GAN models when benchmarking. (#4873)

上级 b9b8c888
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from network.CycleGAN_network import CycleGAN_model from network.CycleGAN_network import CycleGAN_model
from util import utility from util import utility
from util import timer
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
import paddle import paddle
...@@ -291,15 +292,17 @@ class CycleGAN(object): ...@@ -291,15 +292,17 @@ class CycleGAN(object):
loss_name=d_B_trainer.d_loss_B.name, loss_name=d_B_trainer.d_loss_B.name,
build_strategy=build_strategy) build_strategy=build_strategy)
t_time = 0
total_train_batch = 0 # NOTE :used for benchmark total_train_batch = 0 # NOTE :used for benchmark
reader_cost_averager = timer.TimeAverager()
batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
batch_start = time.time()
for data_A, data_B in zip(A_loader(), B_loader()): for data_A, data_B in zip(A_loader(), B_loader()):
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return return
s_time = time.time() reader_cost_averager.record(time.time() - batch_start)
tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B'] tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
## optimize the g_A network ## optimize the g_A network
g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\ g_A_loss, g_A_cyc_loss, g_A_idt_loss, g_B_loss, g_B_cyc_loss,\
...@@ -335,26 +338,31 @@ class CycleGAN(object): ...@@ -335,26 +338,31 @@ class CycleGAN(object):
feed={"input_A": tensor_A, feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0] "fake_pool_A": fake_pool_A})[0]
batch_time = time.time() - s_time batch_cost_averager.record(time.time() - batch_start)
t_time += batch_time
if batch_id % self.cfg.print_freq == 0: if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\ print("epoch{}: batch{}: \n\
d_A_loss: {}; g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; \n\ d_A_loss: {}; g_A_loss: {}; g_A_cyc_loss: {}; g_A_idt_loss: {}; \n\
d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \n\ d_B_loss: {}; g_B_loss: {}; g_B_cyc_loss: {}; g_B_idt_loss: {}; \n\
Batch_time_cost: {}".format( reader_cost: {}, Batch_time_cost: {}"
epoch_id, batch_id, d_A_loss[0], g_A_loss[0], .format(epoch_id, batch_id, d_A_loss[0], g_A_loss[
g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0], g_B_loss[ 0], g_A_cyc_loss[0], g_A_idt_loss[0], d_B_loss[0],
0], g_B_cyc_loss[0], g_B_idt_loss[0], batch_time)) g_B_loss[0], g_B_cyc_loss[0], g_B_idt_loss[0],
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
#NOTE: used for benchmark
total_train_batch += 1 # used for benchmark total_train_batch += 1 # used for benchmark
batch_start = time.time()
# profiler tools # profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq: if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler() profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5: elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return return
# used for continuous evaluation # used for continuous evaluation
if self.cfg.enable_ce and batch_id == 10: if self.cfg.enable_ce and batch_id == 10:
break break
...@@ -398,12 +406,9 @@ class CycleGAN(object): ...@@ -398,12 +406,9 @@ class CycleGAN(object):
B_id2name=self.B_id2name) B_id2name=self.B_id2name)
if self.cfg.save_checkpoints: if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, gen_trainer, utility.checkpoints(epoch_id, self.cfg, gen_trainer, "net_G")
"net_G") utility.checkpoints(epoch_id, self.cfg, d_A_trainer, "net_DA")
utility.checkpoints(epoch_id, self.cfg, d_A_trainer, utility.checkpoints(epoch_id, self.cfg, d_B_trainer, "net_DB")
"net_DA")
utility.checkpoints(epoch_id, self.cfg, d_B_trainer,
"net_DB")
# used for continuous evaluation # used for continuous evaluation
if self.cfg.enable_ce: if self.cfg.enable_ce:
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from network.Pix2pix_network import Pix2pix_model from network.Pix2pix_network import Pix2pix_model
from util import utility from util import utility
from util import timer
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
import sys import sys
...@@ -257,16 +258,16 @@ class Pix2pix(object): ...@@ -257,16 +258,16 @@ class Pix2pix(object):
loss_name=dis_trainer.d_loss.name, loss_name=dis_trainer.d_loss.name,
build_strategy=build_strategy) build_strategy=build_strategy)
t_time = 0
total_train_batch = 0 # used for benchmark total_train_batch = 0 # used for benchmark
reader_cost_averager = timer.TimeAverager()
batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
batch_start = time.time()
for tensor in loader(): for tensor in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return return
s_time = time.time() reader_cost_averager.record(time.time() - batch_start)
# optimize the generator network # optimize the generator network
g_loss_gan, g_loss_l1, fake_B_tmp = exe.run( g_loss_gan, g_loss_l1, fake_B_tmp = exe.run(
...@@ -291,19 +292,24 @@ class Pix2pix(object): ...@@ -291,19 +292,24 @@ class Pix2pix(object):
], ],
feed=tensor) feed=tensor)
batch_time = time.time() - s_time batch_cost_averager.record(time.time() - batch_start)
t_time += batch_time
if batch_id % self.cfg.print_freq == 0: if batch_id % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\ print("epoch{}: batch{}: \n\
g_loss_gan: {}; g_loss_l1: {}; \n\ g_loss_gan: {}; g_loss_l1: {}; \n\
d_loss_real: {}; d_loss_fake: {}; \n\ d_loss_real: {}; d_loss_fake: {}; \n\
Batch_time_cost: {}" reader_cost: {}, Batch_time_cost: {}"
.format(epoch_id, batch_id, g_loss_gan[0], g_loss_l1[ .format(epoch_id, batch_id, g_loss_gan[0], g_loss_l1[
0], d_loss_real[0], d_loss_fake[0], batch_time)) 0], d_loss_real[0], d_loss_fake[0],
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_train_batch += 1 # used for benchmark total_train_batch += 1 # used for benchmark
batch_start = time.time()
# profiler tools # profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq: if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler() profiler.reset_profiler()
......
...@@ -11,11 +11,13 @@ ...@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from network.STGAN_network import STGAN_model from network.STGAN_network import STGAN_model
from util import utility from util import utility
from util import timer
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
import sys import sys
...@@ -344,16 +346,17 @@ class STGAN(object): ...@@ -344,16 +346,17 @@ class STGAN(object):
gen_trainer_program.random_seed = 90 gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90 dis_trainer_program.random_seed = 90
t_time = 0
total_train_batch = 0 # used for benchmark total_train_batch = 0 # used for benchmark
reader_cost_averager = timer.TimeAverager()
batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
batch_start = time.time()
for data in loader(): for data in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return return
s_time = time.time() reader_cost_averager.record(time.time() - batch_start)
# optimize the discriminator network # optimize the discriminator network
fetches = [ fetches = [
dis_trainer.d_loss.name, dis_trainer.d_loss.name,
...@@ -376,20 +379,27 @@ class STGAN(object): ...@@ -376,20 +379,27 @@ class STGAN(object):
g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}" g_loss_fake: {}; g_loss_rec: {}; g_loss_cls: {}"
.format(epoch_id, batch_id, g_loss_fake[0], .format(epoch_id, batch_id, g_loss_fake[0],
g_loss_rec[0], g_loss_cls[0])) g_loss_rec[0], g_loss_cls[0]))
batch_time = time.time() - s_time
t_time += batch_time batch_cost_averager.record(time.time() - batch_start)
if (batch_id + 1) % self.cfg.print_freq == 0: if (batch_id + 1) % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\ print("epoch{}: batch{}: \n\
d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\ d_loss: {}; d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
Batch_time_cost: {}".format(epoch_id, batch_id, d_loss[ reader_cost: {}, Batch_time_cost: {}"
0], d_loss_real[0], d_loss_fake[0], d_loss_cls[0], .format(epoch_id, batch_id, d_loss[0], d_loss_real[0],
d_loss_gp[0], batch_time)) d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_train_batch += 1 # used for benchmark
batch_start = time.time()
if self.cfg.enable_ce and batch_id == 100: if self.cfg.enable_ce and batch_id == 100:
break break
total_train_batch += 1 # used for benchmark
# profiler tools # profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq: if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler() profiler.reset_profiler()
......
...@@ -11,11 +11,13 @@ ...@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from network.StarGAN_network import StarGAN_model from network.StarGAN_network import StarGAN_model
from util import utility from util import utility
from util import timer
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
import sys import sys
...@@ -313,14 +315,17 @@ class StarGAN(object): ...@@ -313,14 +315,17 @@ class StarGAN(object):
gen_trainer_program.random_seed = 90 gen_trainer_program.random_seed = 90
dis_trainer_program.random_seed = 90 dis_trainer_program.random_seed = 90
t_time = 0
total_train_batch = 0 # used for benchmark total_train_batch = 0 # used for benchmark
reader_cost_averager = timer.TimeAverager()
batch_cost_averager = timer.TimeAverager()
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
batch_start = time.time()
for data in loader(): for data in loader():
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return return
s_time = time.time() reader_cost_averager.record(time.time() - batch_start)
d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run( d_loss_real, d_loss_fake, d_loss, d_loss_cls, d_loss_gp = exe.run(
dis_trainer_program, dis_trainer_program,
fetch_list=[ fetch_list=[
...@@ -344,22 +349,27 @@ class StarGAN(object): ...@@ -344,22 +349,27 @@ class StarGAN(object):
.format(epoch_id, batch_id, g_loss_fake[0], .format(epoch_id, batch_id, g_loss_fake[0],
g_loss_rec[0], g_loss_cls[0])) g_loss_rec[0], g_loss_cls[0]))
batch_time = time.time() - s_time batch_cost_averager.record(time.time() - batch_start)
t_time += batch_time
if (batch_id + 1) % self.cfg.print_freq == 0: if (batch_id + 1) % self.cfg.print_freq == 0:
print("epoch{}: batch{}: \n\ print("epoch{}: batch{}: \n\
d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\ d_loss_real: {}; d_loss_fake: {}; d_loss_cls: {}; d_loss_gp: {} \n\
Batch_time_cost: {}".format( reader_cost: {}, Batch_time_cost: {}"
epoch_id, batch_id, d_loss_real[0], d_loss_fake[ .format(epoch_id, batch_id, d_loss_real[0],
0], d_loss_cls[0], d_loss_gp[0], batch_time)) d_loss_fake[0], d_loss_cls[0], d_loss_gp[0],
reader_cost_averager.get_average(),
batch_cost_averager.get_average()))
reader_cost_averager.reset()
batch_cost_averager.reset()
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_train_batch += 1 # used for benchmark
batch_start = time.time()
# used for ce # used for ce
if self.cfg.enable_ce and batch_id == 100: if self.cfg.enable_ce and batch_id == 100:
break break
total_train_batch += 1 # used for benchmark
# profiler tools # profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq: if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler() profiler.reset_profiler()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class TimeAverager(object):
    """Accumulate duration samples and report their arithmetic mean.

    Callers feed one elapsed-time value per event through ``record`` and
    read the mean of everything recorded since the last ``reset`` through
    ``get_average``.
    """

    def __init__(self):
        # Start from an empty window; reset() owns the attribute definitions.
        self.reset()

    def reset(self):
        """Discard all recorded samples and start a new averaging window."""
        self._cnt = 0
        # Float accumulator: this module does not enable true division, so an
        # integer total would make get_average() floor-divide under Python 2
        # when only integer samples were recorded.
        self._total_time = 0.0

    def record(self, usetime):
        """Add one sample to the current window.

        Args:
            usetime: Elapsed time of a single event, as a number.
        """
        self._cnt += 1
        self._total_time += usetime

    def get_average(self):
        """Return the mean of the samples recorded since the last reset.

        Returns:
            ``0`` when nothing has been recorded (avoids dividing by zero),
            otherwise the accumulated total divided by the sample count.
        """
        if self._cnt == 0:
            return 0
        return self._total_time / self._cnt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册