未验证 提交 9d18809a 编写于 作者: Y Yiqun Liu 提交者: GitHub

Correct a typo and add the support of profiler. (#2713)

上级 b2a6586d
......@@ -54,6 +54,11 @@ def parse_args():
type=str2bool,
default=True,
help='Whether using gpu in parallel [True|False]')
parser.add_argument(
'--profile',
type=str2bool,
default=False,
help='Whether profiling the trainning [True|False]')
parser.add_argument(
'--use_py_reader',
type=str2bool,
......
......@@ -20,12 +20,13 @@ import numpy as np
import time
import os
import random
import math
import contextlib
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor
import reader
......@@ -48,6 +49,15 @@ import pickle
SEED = 123
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', '/tmp/paddingrnn.profile'):
yield
else:
yield
def get_current_model_para(train_prog, train_exe):
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
......@@ -273,8 +283,10 @@ def main():
batch_start_time = time.time()
fetch_outs = exe.run(train_program,
feed=input_data_feed,
fetch_list=[loss.name, "learning_rate", \
last_hidden.name, last_cell.name ],
fetch_list=[
loss.name, "learning_rate",
last_hidden.name, last_cell.name
],
use_program_cache=True)
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
......@@ -324,14 +336,16 @@ def main():
fetch_outs = exe.run(train_program,
feed=data_feeds,
fetch_list=[loss.name, "learning_rate", \
last_hidden.name, last_cell.name ],
fetch_list=[
loss.name, "learning_rate",
last_hidden.name, last_cell.name
],
use_program_cache=True)
cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1])
init_hidden = np.array(fetch_list[2])
init_cell = np.array( fetch_list[3] )
init_hidden = np.array(fetch_outs[2])
init_cell = np.array(fetch_outs[3])
total_loss += cost_train
iters += config.num_steps
......@@ -424,7 +438,9 @@ def main():
executor=exe, dirname=save_model_dir, main_program=main_program)
print("Saved model to: %s.\n" % save_model_dir)
train()
with profile_context(args.profile):
train()
test_ppl = eval(test_data)
print("Test ppl:", test_ppl[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册