提交 1d63dafd 编写于 作者: G gongweibao

fix

上级 bb696e5b
......@@ -32,6 +32,7 @@ class TrainTaskConfig(object):
start_step = 0
# the frequency to save trained models.
save_freq = 10000
profile=True
class InferTaskConfig(object):
......
......@@ -7,6 +7,7 @@ import time
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import reader
from config import *
......@@ -130,6 +131,9 @@ def parse_args():
default=100,
help="Fetch outputs steps.")
#parser.add_argument(
# '--profile', action='store_true', help='If set, profile a few steps.')
args = parser.parse_args()
# Append args related to dict
......@@ -467,8 +471,8 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
#build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
exec_strategy = fluid.ExecutionStrategy()
if args.update_method == "nccl2":
exec_strategy.num_threads = 1
#if args.update_method == "nccl2":
exec_strategy.num_threads = 1
logging.info("begin executor")
train_exe = fluid.ParallelExecutor(
......@@ -509,11 +513,22 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
feed_dict_list = prepare_feed_dict_list(data_generator,
init_flag, dev_count)
if TrainTaskConfig.profile and batch_id == 5:
logging.info("begin profiler")
profiler.start_profiler("All")
profiler.reset_profiler()
elif TrainTaskConfig.profile and batch_id == 10:
logging.info("end profiler")
#logging.info("profiling total time: ", time.time() - start_time)
profiler.stop_profiler("total", "./transformer_local_profile_{}_pass{}".format(batch_id, pass_id))
sys.exit(0)
logging.info("batch_id:{}".format(batch_id))
outs = train_exe.run(
fetch_list=[sum_cost.name, token_num.name] if batch_id % args.fetch_steps == 0 else[],
fetch_list=[sum_cost.name, token_num.name] if (batch_id % args.fetch_steps == 0 or TrainTaskConfig.profile) else[],
feed=feed_dict_list)
if batch_id % args.fetch_steps == 0 and batch_id > 0:
if (batch_id % args.fetch_steps == 0 and batch_id > 0):
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[
1])
# sum the cost from multi-devices
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册