未验证 提交 e99b607e 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enhance the training processing of rnn_rearch (#3168)

* Refine the print and add timing for each step and epoch.

* Enable the profile.
上级 2ef93ad2
......@@ -119,5 +119,15 @@ def parse_args():
help="The flag indicating whether to run the task "
"for continuous evaluation.")
parser.add_argument(
"--parallel",
action='store_true',
help="Whether execute with the data_parallel mode.")
parser.add_argument(
"--profile",
action='store_true',
help="Whether enable the profile.")
args = parser.parse_args()
return args
......@@ -20,12 +20,13 @@ import numpy as np
import time
import os
import random
import math
import contextlib
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor
import reader
......@@ -45,7 +46,16 @@ import pickle
SEED = 123
def train():
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', 'seq2seq.profile'):
yield
else:
yield
def main():
args = parse_args()
num_layers = args.num_layers
......@@ -106,6 +116,29 @@ def train():
exe = Executor(place)
exe.run(framework.default_startup_program())
device_count = len(fluid.cuda_places()) if args.use_gpu else len(
fluid.cpu_places())
if device_count > 1:
raise Exception("Training using multi-GPUs is not supported now.")
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = device_count
exec_strategy.num_iteration_per_drop_scope = 100
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = False
# build_strategy.fuse_all_optimizer_ops = True
if args.parallel:
train_program = fluid.compiler.CompiledProgram(
framework.default_main_program()).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
else:
train_program = framework.default_main_program()
train_data_prefix = args.train_data_prefix
eval_data_prefix = args.eval_data_prefix
test_data_prefix = args.test_data_prefix
......@@ -160,12 +193,12 @@ def train():
return ppl
def train():
ce_time = []
ce_ppl = []
max_epoch = args.max_epoch
for epoch_id in range(max_epoch):
start_time = time.time()
print("epoch id", epoch_id)
if args.enable_ce:
train_data_iter = reader.get_data_iter(train_data, batch_size, enable_ce=True)
else:
......@@ -174,10 +207,12 @@ def train():
total_loss = 0
word_count = 0.0
batch_times = []
for batch_id, batch in enumerate(train_data_iter):
batch_start_time = time.time()
input_data_feed, word_num = prepare_input(batch, epoch_id=epoch_id)
fetch_outs = exe.run(feed=input_data_feed,
fetch_outs = exe.run(program=train_program,
feed=input_data_feed,
fetch_list=[loss.name],
use_program_cache=True)
......@@ -185,16 +220,26 @@ def train():
total_loss += cost_train * batch_size
word_count += word_num
batch_end_time = time.time()
batch_time = batch_end_time - batch_start_time
batch_times.append(batch_time)
if batch_id > 0 and batch_id % 100 == 0:
print("ppl", batch_id, np.exp(total_loss / word_count))
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f"
% (epoch_id, batch_id, batch_time, np.exp(total_loss / word_count)))
ce_ppl.append(np.exp(total_loss / word_count))
total_loss = 0.0
word_count = 0.0
end_time = time.time()
time_gap = end_time - start_time
ce_time.append(time_gap)
epoch_time = end_time - start_time
ce_time.append(epoch_time)
print(
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
if not args.profile:
dir_name = args.model_path + "/epoch_" + str(epoch_id)
print("begin to save", dir_name)
fluid.io.save_params(exe, dir_name)
......@@ -218,6 +263,9 @@ def train():
print("kpis\ttrain_ppl_card%s\t%f" %
(card_num, _ppl))
with profile_context(args.profile):
train()
def get_cards():
num = 0
......@@ -228,4 +276,4 @@ def get_cards():
if __name__ == '__main__':
train()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册