未验证 提交 9d18809a 编写于 作者: Y Yiqun Liu 提交者: GitHub

Correct a typo and add the support of profiler. (#2713)

上级 b2a6586d
...@@ -54,6 +54,11 @@ def parse_args(): ...@@ -54,6 +54,11 @@ def parse_args():
type=str2bool, type=str2bool,
default=True, default=True,
help='Whether using gpu in parallel [True|False]') help='Whether using gpu in parallel [True|False]')
parser.add_argument(
'--profile',
type=str2bool,
default=False,
help='Whether profiling the trainning [True|False]')
parser.add_argument( parser.add_argument(
'--use_py_reader', '--use_py_reader',
type=str2bool, type=str2bool,
......
...@@ -20,12 +20,13 @@ import numpy as np ...@@ -20,12 +20,13 @@ import numpy as np
import time import time
import os import os
import random import random
import math import math
import contextlib
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
import reader import reader
...@@ -48,6 +49,15 @@ import pickle ...@@ -48,6 +49,15 @@ import pickle
SEED = 123 SEED = 123
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', '/tmp/paddingrnn.profile'):
yield
else:
yield
def get_current_model_para(train_prog, train_exe): def get_current_model_para(train_prog, train_exe):
param_list = train_prog.block(0).all_parameters() param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list] param_name_list = [p.name for p in param_list]
...@@ -273,8 +283,10 @@ def main(): ...@@ -273,8 +283,10 @@ def main():
batch_start_time = time.time() batch_start_time = time.time()
fetch_outs = exe.run(train_program, fetch_outs = exe.run(train_program,
feed=input_data_feed, feed=input_data_feed,
fetch_list=[loss.name, "learning_rate", \ fetch_list=[
last_hidden.name, last_cell.name ], loss.name, "learning_rate",
last_hidden.name, last_cell.name
],
use_program_cache=True) use_program_cache=True)
batch_time = time.time() - batch_start_time batch_time = time.time() - batch_start_time
batch_times.append(batch_time) batch_times.append(batch_time)
...@@ -324,14 +336,16 @@ def main(): ...@@ -324,14 +336,16 @@ def main():
fetch_outs = exe.run(train_program, fetch_outs = exe.run(train_program,
feed=data_feeds, feed=data_feeds,
fetch_list=[loss.name, "learning_rate", \ fetch_list=[
last_hidden.name, last_cell.name ], loss.name, "learning_rate",
last_hidden.name, last_cell.name
],
use_program_cache=True) use_program_cache=True)
cost_train = np.array(fetch_outs[0]) cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1]) lr = np.array(fetch_outs[1])
init_hidden = np.array(fetch_list[2]) init_hidden = np.array(fetch_outs[2])
init_cell = np.array( fetch_list[3] ) init_cell = np.array(fetch_outs[3])
total_loss += cost_train total_loss += cost_train
iters += config.num_steps iters += config.num_steps
...@@ -424,7 +438,9 @@ def main(): ...@@ -424,7 +438,9 @@ def main():
executor=exe, dirname=save_model_dir, main_program=main_program) executor=exe, dirname=save_model_dir, main_program=main_program)
print("Saved model to: %s.\n" % save_model_dir) print("Saved model to: %s.\n" % save_model_dir)
train() with profile_context(args.profile):
train()
test_ppl = eval(test_data) test_ppl = eval(test_data)
print("Test ppl:", test_ppl[0]) print("Test ppl:", test_ppl[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册