未验证 提交 42d56927 编写于 作者: H hysunflower 提交者: GitHub

add_profiler_max_iter_cyclegan (#4519)

上级 963951e8
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
from network.CycleGAN_network import CycleGAN_model from network.CycleGAN_network import CycleGAN_model
from util import utility from util import utility
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler
import paddle import paddle
import sys import sys
import time import time
...@@ -292,9 +293,12 @@ class CycleGAN(object): ...@@ -292,9 +293,12 @@ class CycleGAN(object):
t_time = 0 t_time = 0
total_train_batch = 0 # NOTE :used for benchmark
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
for data_A, data_B in zip(A_loader(), B_loader()): for data_A, data_B in zip(A_loader(), B_loader()):
if self.cfg.max_iter and total_train_batch == self.cfg.max_iter: # used for benchmark
return
s_time = time.time() s_time = time.time()
tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B'] tensor_A, tensor_B = data_A[0]['input_A'], data_B[0]['input_B']
## optimize the g_A network ## optimize the g_A network
...@@ -344,6 +348,13 @@ class CycleGAN(object): ...@@ -344,6 +348,13 @@ class CycleGAN(object):
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
#NOTE: used for benchmark
total_train_batch += 1 # used for benchmark
# profiler tools
if self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq:
profiler.reset_profiler()
elif self.cfg.profile and epoch_id == 0 and batch_id == self.cfg.print_freq + 5:
return
# used for continuous evaluation # used for continuous evaluation
if self.cfg.enable_ce and batch_id == 10: if self.cfg.enable_ce and batch_id == 10:
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册