提交 097bad64 编写于 作者: H hysunflower 提交者: Jinhua Liang

update the profiler of seq2seq (#3959)

* update the profiler of seq2seq

* modify profiler args
上级 2d5dd1d8
...@@ -123,15 +123,10 @@ def parse_args(): ...@@ -123,15 +123,10 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--profile", action='store_true', help="Whether enable the profile.") "--profile", action='store_true', help="Whether enable the profile.")
# NOTE: profiler args, used for benchmark # NOTE: profiler args, used for benchmark
parser.add_argument(
"--is_profiler",
type=int,
default=0,
help="The switch of profiler tools. (used for benchmark)")
parser.add_argument( parser.add_argument(
"--profiler_path", "--profiler_path",
type=str, type=str,
default='./', default='./seq2seq.profile',
help="the profiler output file path. (used for benchmark)") help="the profiler output file path. (used for benchmark)")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -47,9 +47,9 @@ import pickle ...@@ -47,9 +47,9 @@ import pickle
@contextlib.contextmanager @contextlib.contextmanager
def profile_context(profile=True): def profile_context(profile=True, profiler_path='./seq2seq.profile'):
if profile: if profile:
with profiler.profiler('All', 'total', 'seq2seq.profile'): with profiler.profiler('All', 'total', profiler_path):
yield yield
else: else:
yield yield
...@@ -216,10 +216,9 @@ def main(): ...@@ -216,10 +216,9 @@ def main():
word_count = 0.0 word_count = 0.0
# profiler tools # profiler tools
if args.is_profiler and epoch_id == 0 and batch_id == 100: if args.profile and epoch_id == 0 and batch_id == 100:
profiler.start_profiler("All") profiler.reset_profiler()
elif args.is_profiler and epoch_id == 0 and batch_id == 105: elif args.profile and epoch_id == 0 and batch_id == 105:
profiler.stop_profiler("total", args.profiler_path)
return return
end_time = time.time() end_time = time.time()
...@@ -252,7 +251,7 @@ def main(): ...@@ -252,7 +251,7 @@ def main():
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time)) print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl)) print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl))
with profile_context(args.profile): with profile_context(args.profile, args.profiler_path):
train() train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册