提交 ef88e63f 编写于 作者: G guosheng

Refine the training output in Transformer

上级 ad1c3857
import os import os
import time
import numpy as np import numpy as np
import paddle import paddle
...@@ -159,6 +160,7 @@ def main(): ...@@ -159,6 +160,7 @@ def main():
ModelHyperParams.d_model), place) ModelHyperParams.d_model), place)
for pass_id in xrange(TrainTaskConfig.pass_num): for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
# The current program desc is coupled with batch_size, thus all # The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently. # mini-batches must have the same number of instances currently.
...@@ -175,15 +177,17 @@ def main(): ...@@ -175,15 +177,17 @@ def main():
fetch_list=[sum_cost, avg_cost], fetch_list=[sum_cost, avg_cost],
use_program_cache=True) use_program_cache=True)
sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1]) sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) + print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
" sum_cost = " + str(sum_cost_val) + " avg_cost = " + str( (pass_id, batch_id, sum_cost_val, avg_cost_val,
avg_cost_val) + " ppl = " + str( np.exp([min(avg_cost_val[0], 100)])))
np.exp([min(avg_cost_val[0], 100)])))
# Validate and save the model for inference. # Validate and save the model for inference.
val_sum_cost, val_avg_cost = test(exe) val_sum_cost, val_avg_cost = test(exe)
print("pass_id = " + str(pass_id) + " val_sum_cost = " + str( pass_end_time = time.time()
val_sum_cost) + " val_avg_cost = " + str(val_avg_cost) + time_consumed = pass_end_time - pass_start_time
" val_ppl = " + str(np.exp(min(val_avg_cost, 100)))) print(
"epoch: %d, val sum loss: %f, val avg loss: %f, val ppl: %f, consumed %fs"
% (pass_id, val_sum_cost, val_avg_cost,
np.exp([min(val_avg_cost, 100)]), time_consumed))
fluid.io.save_inference_model( fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir, os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"), "pass_" + str(pass_id) + ".infer.model"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册