未验证 提交 50bd7159 编写于 作者: Z zhengya01 提交者: GitHub

add ce for dygraph_seq2seq (#4274)

上级 5cb47352
......@@ -57,7 +57,7 @@ def main():
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
with fluid.dygraph.guard(place):
args.enable_ce = True
#args.enable_ce = True
if args.enable_ce:
fluid.default_startup_program().random_seed = 102
fluid.default_main_program().random_seed = 102
......@@ -138,6 +138,8 @@ def main():
model.train()
return ppl
ce_time = []
ce_ppl = []
max_epoch = args.max_epoch
for epoch_id in range(max_epoch):
model.train()
......@@ -170,6 +172,7 @@ def main():
print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
(epoch_id, batch_id, batch_time,
np.exp(total_loss.numpy() / word_count)))
ce_ppl.append(np.exp(total_loss.numpy() / word_count))
total_loss = 0.0
word_count = 0.0
......@@ -178,6 +181,7 @@ def main():
print(
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
ce_time.append(epoch_time)
dir_name = os.path.join(args.model_path,
......@@ -190,6 +194,18 @@ def main():
test_ppl = eval(test_data)
print("test ppl", test_ppl)
if args.enable_ce:
card_num = get_cards()
_ppl = 0
_time = 0
try:
_time = ce_time[-1]
_ppl = ce_ppl[-1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl))
def get_cards():
num = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册