提交 62e824db 编写于 作者: H hysunflower 提交者: Jinhua Liang

add_for_cyclegan_models (#3997)

上级 c8545d66
...@@ -69,6 +69,11 @@ add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.") ...@@ -69,6 +69,11 @@ add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.")
add_arg('run_test', bool, True, "Whether to run test.") add_arg('run_test', bool, True, "Whether to run test.")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.") add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('profile', bool, False, "Whether to profile.") add_arg('profile', bool, False, "Whether to profile.")
# NOTE: args for profiler, used for benchmark
add_arg('profiler_path', str, './profiler_cyclegan', "the path of profiler output files. used for benchmark")
add_arg('max_iter', int, 0, "the max batch nums to train. used for benchmark")
add_arg('run_ce', bool, False, "Whether to run for model ce.") add_arg('run_ce', bool, False, "Whether to run for model ce.")
# yapf: enable # yapf: enable
...@@ -214,9 +219,14 @@ def train(args): ...@@ -214,9 +219,14 @@ def train(args):
loss_name=d_A_trainer.d_loss_A.name, loss_name=d_A_trainer.d_loss_A.name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
total_batch_num = 0 # this is for benchmark
for epoch in range(args.epoch): for epoch in range(args.epoch):
batch_id = 0 batch_id = 0
for i in range(max_images_num): for i in range(max_images_num):
if args.max_iter and total_batch_num == args.max_iter: # this for benchmark
return
data_A = next(A_reader) data_A = next(A_reader)
data_B = next(B_reader) data_B = next(B_reader)
tensor_A = fluid.LoDTensor() tensor_A = fluid.LoDTensor()
...@@ -265,6 +275,12 @@ def train(args): ...@@ -265,6 +275,12 @@ def train(args):
losses[1].append(d_A_loss[0]) losses[1].append(d_A_loss[0])
sys.stdout.flush() sys.stdout.flush()
batch_id += 1 batch_id += 1
total_batch_num = total_batch_num + 1 # this is for benchmark
# profiler tools for benchmark
if args.profile and epoch == 0 and batch_id == 10:
profiler.reset_profiler()
elif args.profile and epoch == 0 and batch_id == 15:
return
if args.run_test and not args.run_ce: if args.run_test and not args.run_ce:
test(epoch) test(epoch)
...@@ -281,7 +297,7 @@ if __name__ == "__main__": ...@@ -281,7 +297,7 @@ if __name__ == "__main__":
print_arguments(args) print_arguments(args)
if args.profile: if args.profile:
if args.use_gpu: if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof: with profiler.profiler('All', 'total', args.profiler_path) as prof:
train(args) train(args)
else: else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof: with profiler.profiler("CPU", sorted_key='total') as cpuprof:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册