提交 bb48a52c 编写于 作者: H hysunflower 提交者: Jinhua Liang

Add maxiter and profiler for padding models (#3947)

* add profiler and max_iter for paddingrnn

* add profiler for paddingrnn
上级 097bad64
......@@ -80,5 +80,8 @@ def parse_args():
parser.add_argument('--enable_ce', action='store_true')
parser.add_argument('--batch_size', type=int, default=0, help='batch size')
parser.add_argument('--max_epoch', type=int, default=0, help='max epoch')
# NOTE: args for profiler, used for benchmark
parser.add_argument('--profiler_path', type=str, default='/tmp/paddingrnn.profile', help='the profiler output file path. used for benchmark')
args = parser.parse_args()
return args
......@@ -25,6 +25,7 @@ import contextlib
from distutils.dir_util import mkpath
import paddle
import paddle.fluid as fluid
from paddle.fluid import profiler
import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor
......@@ -50,9 +51,9 @@ SEED = 123
@contextlib.contextmanager
def profile_context(profile=True):
def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
if profile:
with profiler.profiler('All', 'total', '/tmp/paddingrnn.profile'):
with profiler.profiler('All', 'total', profiler_path):
yield
else:
yield
......@@ -318,6 +319,12 @@ def main():
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
# profiler tools for benchmark
if args.profile and batch_id == log_interval:
profiler.reset_profiler()
elif args.profile and batch_id == (log_interval + 5):
break
ppl = np.exp(total_loss / iters)
return ppl
......@@ -371,6 +378,11 @@ def main():
% (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
batch_id += 1
# profiler tools for benchmark
if args.profile and batch_id == log_interval:
profiler.reset_profiler()
elif args.profile and batch_id == (log_interval + 5):
break
except fluid.core.EOFException:
dataloader.reset()
......@@ -455,7 +467,7 @@ def main():
fluid.save(main_program, save_model_dir)
print("Saved model to: %s.\n" % save_model_dir)
with profile_context(args.profile):
with profile_context(args.profile, args.profiler_path):
train()
test_ppl = eval(test_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册