未验证 提交 42d56927 编写于 作者: H hysunflower 提交者: GitHub

add_profiler_max_iter_cyclegan (#4519)

上级 963951e8
......@@ -18,6 +18,7 @@ from __future__ import print_function
from network.CycleGAN_network import CycleGAN_model
from util import utility
import paddle.fluid as fluid
from paddle.fluid import profiler
import paddle
import sys
import time
......@@ -292,9 +293,12 @@ class CycleGAN(object):
t_time = 0
total_train_batch = 0 # NOTE :used for benchmark
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for data_A, data_B in zip(A_loader(), B_loader()):
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time()
tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
## optimize the g_A network
......@@ -344,6 +348,13 @@ class CycleGAN(object):
sys.stdout.flush()
batch_id += 1
#NOTE: used for benchmark
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
# used for continuous evaluation
if self.cfg.enable_ce and batch_id == 10:
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册