提交 097bad64 编写于 作者: H hysunflower 提交者: Jinhua Liang

update the profiler of seq2seq (#3959)

* update the profiler of seq2seq

* modify profiler args
上级 2d5dd1d8
......@@ -123,15 +123,10 @@ def parse_args():
parser.add_argument(
"--profile", action='store_true', help="Whether enable the profile.")
# NOTE: profiler args, used for benchmark
parser.add_argument(
"--is_profiler",
type=int,
default=0,
help="The switch of profiler tools. (used for benchmark)")
parser.add_argument(
"--profiler_path",
type=str,
default='./',
default='./seq2seq.profile',
help="the profiler output file path. (used for benchmark)")
args = parser.parse_args()
return args
......@@ -47,9 +47,9 @@ import pickle
@contextlib.contextmanager
def profile_context(profile=True):
def profile_context(profile=True, profiler_path='./seq2seq.profile'):
if profile:
with profiler.profiler('All', 'total', 'seq2seq.profile'):
with profiler.profiler('All', 'total', profiler_path):
yield
else:
yield
......@@ -216,10 +216,9 @@ def main():
word_count = 0.0
# profiler tools
if args.is_profiler and epoch_id == 0 and batch_id == 100:
profiler.start_profiler("All")
elif args.is_profiler and epoch_id == 0 and batch_id == 105:
profiler.stop_profiler("total", args.profiler_path)
if args.profile and epoch_id == 0 and batch_id == 100:
profiler.reset_profiler()
elif args.profile and epoch_id == 0 and batch_id == 105:
return
end_time = time.time()
......@@ -252,7 +251,7 @@ def main():
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl))
with profile_context(args.profile):
with profile_context(args.profile, args.profiler_path):
train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册