提交 62e824db 编写于 作者: H hysunflower 提交者: Jinhua Liang

add_for_cyclegan_models (#3997)

上级 c8545d66
......@@ -69,6 +69,11 @@ add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.")
add_arg('run_test', bool, True, "Whether to run test.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('profile', bool, False, "Whether to profile.")
# NOTE: args for profiler, used for benchmark
add_arg('profiler_path', str, './profiler_cyclegan', "the path of profiler output files. used for benchmark")
add_arg('max_iter', int, 0, "the max batch nums to train. used for benchmark")
add_arg('run_ce', bool, False, "Whether to run for model ce.")
# yapf: enable
......@@ -214,9 +219,14 @@ def train(args):
loss_name=d_A_trainer.d_loss_A.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
total_batch_num = 0 # this is for benchmark
for epoch in range(args.epoch):
batch_id = 0
for i in range(max_images_num):
if args.max_iter and total_batch_num == args.max_iter: # this for benchmark
return
data_A = next(A_reader)
data_B = next(B_reader)
tensor_A = fluid.LoDTensor()
......@@ -265,6 +275,12 @@ def train(args):
losses[1].append(d_A_loss[0])
sys.stdout.flush()
batch_id += 1
total_batch_num = total_batch_num + 1 # this is for benchmark
# profiler tools for benchmark
if args.profile and epoch == 0 and batch_id == 10:
profiler.reset_profiler()
elif args.profile and epoch == 0 and batch_id == 15:
return
if args.run_test and not args.run_ce:
test(epoch)
......@@ -281,7 +297,7 @@ if __name__ == "__main__":
print_arguments(args)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
with profiler.profiler('All', 'total', args.profiler_path) as prof:
train(args)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册