未验证 提交 8a31b1ca 编写于 作者: H hong 提交者: GitHub

Add tokens/sec throughput to the per-100-batch and per-epoch training logs; test=develop (#4875)

上级 fd2ff205
...@@ -97,8 +97,7 @@ def main(): ...@@ -97,8 +97,7 @@ def main():
dropout=dropout) dropout=dropout)
loss = model.build_graph() loss = model.build_graph()
inference_program = train_program.clone(for_test=True) inference_program = train_program.clone(for_test=True)
clip=fluid.clip.GradientClipByGlobalNorm( clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=max_grad_norm)
clip_norm=max_grad_norm)
lr = args.learning_rate lr = args.learning_rate
opt_type = args.optimizer opt_type = args.optimizer
if opt_type == "sgd": if opt_type == "sgd":
...@@ -190,8 +189,10 @@ def main(): ...@@ -190,8 +189,10 @@ def main():
total_loss = 0 total_loss = 0
word_count = 0.0 word_count = 0.0
batch_times = [] batch_times = []
time_interval = 0.0
batch_start_time = time.time()
epoch_word_count = 0.0
for batch_id, batch in enumerate(train_data_iter): for batch_id, batch in enumerate(train_data_iter):
batch_start_time = time.time()
input_data_feed, word_num = prepare_input( input_data_feed, word_num = prepare_input(
batch, epoch_id=epoch_id) batch, epoch_id=epoch_id)
word_count += word_num word_count += word_num
...@@ -206,27 +207,34 @@ def main(): ...@@ -206,27 +207,34 @@ def main():
batch_end_time = time.time() batch_end_time = time.time()
batch_time = batch_end_time - batch_start_time batch_time = batch_end_time - batch_start_time
batch_times.append(batch_time) batch_times.append(batch_time)
time_interval += batch_time
epoch_word_count += word_num
if batch_id > 0 and batch_id % 100 == 0: if batch_id > 0 and batch_id % 100 == 0:
print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" % print(
(epoch_id, batch_id, batch_time, "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f; speed: %0.5f tokens/sec"
np.exp(total_loss / word_count))) % (epoch_id, batch_id, batch_time,
np.exp(total_loss / word_count),
word_count / time_interval))
ce_ppl.append(np.exp(total_loss / word_count)) ce_ppl.append(np.exp(total_loss / word_count))
total_loss = 0.0 total_loss = 0.0
word_count = 0.0 word_count = 0.0
time_interval = 0.0
# profiler tools # profiler tools
if args.profile and epoch_id == 0 and batch_id == 100: if args.profile and epoch_id == 0 and batch_id == 100:
profiler.reset_profiler() profiler.reset_profiler()
elif args.profile and epoch_id == 0 and batch_id == 105: elif args.profile and epoch_id == 0 and batch_id == 105:
return return
batch_start_time = time.time()
end_time = time.time() end_time = time.time()
epoch_time = end_time - start_time epoch_time = end_time - start_time
ce_time.append(epoch_time) ce_time.append(epoch_time)
print( print(
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n" "\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step; speed: %0.5f tokens/sec\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times))) % (epoch_id, epoch_time, sum(batch_times) / len(batch_times),
epoch_word_count / sum(batch_times)))
if not args.profile: if not args.profile:
save_path = os.path.join(args.model_path, save_path = os.path.join(args.model_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册