未验证 提交 dba5be0f 编写于 作者: Z zhengya01 提交者: GitHub

add ce for ptb_lm (#4328)

上级 f0a09c3a
...@@ -220,6 +220,8 @@ def train_ptb_lm(): ...@@ -220,6 +220,8 @@ def train_ptb_lm():
if args.use_gpu == True: if args.use_gpu == True:
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
# check if paddlepaddle version is satisfied # check if paddlepaddle version is satisfied
model_check.check_version() model_check.check_version()
...@@ -363,9 +365,9 @@ def train_ptb_lm(): ...@@ -363,9 +365,9 @@ def train_ptb_lm():
print("eval finished") print("eval finished")
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
print("ppl ", batch_id, ppl[0]) print("ppl ", batch_id, ppl[0])
if args.ce:
print("kpis\ttest_ppl\t%0.3f" % ppl[0])
ce_time = []
ce_ppl = []
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(max_grad_norm) grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(max_grad_norm)
for epoch_id in range(max_epoch): for epoch_id in range(max_epoch):
ptb_model.train() ptb_model.train()
...@@ -412,6 +414,8 @@ def train_ptb_lm(): ...@@ -412,6 +414,8 @@ def train_ptb_lm():
print("one epoch finished", epoch_id) print("one epoch finished", epoch_id)
print("time cost ", time.time() - start_time) print("time cost ", time.time() - start_time)
ppl = np.exp(total_loss / iters) ppl = np.exp(total_loss / iters)
ce_time.append(time.time() - start_time)
ce_ppl.append(ppl[0])
print("-- Epoch:[%d]; ppl: %.5f" % (epoch_id, ppl[0])) print("-- Epoch:[%d]; ppl: %.5f" % (epoch_id, ppl[0]))
if batch_size <= 20 and epoch_id == 0 and ppl[0] > 1000: if batch_size <= 20 and epoch_id == 0 and ppl[0] > 1000:
...@@ -421,8 +425,6 @@ def train_ptb_lm(): ...@@ -421,8 +425,6 @@ def train_ptb_lm():
print("Abort this training process and please start again.") print("Abort this training process and please start again.")
return return
if args.ce:
print("kpis\ttrain_ppl\t%0.3f" % ppl[0])
save_model_dir = os.path.join(args.save_model_dir, save_model_dir = os.path.join(args.save_model_dir,
str(epoch_id), 'params') str(epoch_id), 'params')
fluid.save_dygraph(ptb_model.state_dict(), save_model_dir) fluid.save_dygraph(ptb_model.state_dict(), save_model_dir)
...@@ -430,6 +432,17 @@ def train_ptb_lm(): ...@@ -430,6 +432,17 @@ def train_ptb_lm():
eval(ptb_model, valid_data) eval(ptb_model, valid_data)
if args.ce:
_ppl = 0
_time = 0
try:
_time = ce_time[-1]
_ppl = ce_ppl[-1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (dev_count, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (dev_count, _ppl))
eval(ptb_model, test_data) eval(ptb_model, test_data)
train_ptb_lm() train_ptb_lm()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册