未验证 提交 660489ac 编写于 作者: L liym27 提交者: GitHub

Add log and check predicted scores. test=develop (#23506)

上级 9bc223c8
......@@ -29,6 +29,7 @@ trainer_count = 1
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
SEED = 10
step_num = 10
def train_static(args, batch_generator):
......@@ -117,7 +118,7 @@ def train_static(args, batch_generator):
batch_id += 1
step_idx += 1
total_batch_num = total_batch_num + 1
if step_idx == 10:
if step_idx == step_num:
if args.save_dygraph_model_path:
model_path = os.path.join(args.save_static_model_path,
"transformer")
......@@ -201,7 +202,7 @@ def train_dygraph(args, batch_generator):
avg_batch_time = time.time()
batch_id += 1
step_idx += 1
if step_idx == 10:
if step_idx == step_num:
if args.save_dygraph_model_path:
model_dir = os.path.join(args.save_dygraph_model_path)
if not os.path.exists(model_dir):
......@@ -250,10 +251,11 @@ def predict_dygraph(args, batch_generator):
transformer.eval()
step_idx = 0
speed_list = []
for input_data in test_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_src_attn_bias) = input_data
finished_seq, finished_scores = transformer.beam_search(
seq_ids, seq_scores = transformer.beam_search(
src_word,
src_pos,
src_slf_attn_bias,
......@@ -263,12 +265,28 @@ def predict_dygraph(args, batch_generator):
eos_id=args.eos_idx,
beam_size=args.beam_size,
max_len=args.max_out_len)
finished_seq = finished_seq.numpy()
finished_scores = finished_scores.numpy()
seq_ids = seq_ids.numpy()
seq_scores = seq_scores.numpy()
if step_idx % args.print_step == 0:
if step_idx == 0:
logging.info(
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0]))
avg_batch_time = time.time()
else:
speed = args.print_step / (time.time() - avg_batch_time)
speed_list.append(speed)
logging.info(
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
avg_batch_time = time.time()
step_idx += 1
if step_idx == 10:
if step_idx == step_num:
break
return finished_seq
logging.info("Dygraph Predict: avg_speed: %.4f step/s" %
(np.mean(speed_list)))
return seq_ids, seq_scores
def predict_static(args, batch_generator):
......@@ -318,16 +336,34 @@ def predict_static(args, batch_generator):
loader.set_batch_generator(batch_generator, places=place)
step_idx = 0
speed_list = []
for feed_dict in loader:
seq_ids, seq_scores = exe.run(
test_prog,
feed=feed_dict,
fetch_list=[out_ids.name, out_scores.name],
return_numpy=True)
if step_idx % args.print_step == 0:
if step_idx == 0:
logging.info(
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f,"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0]))
avg_batch_time = time.time()
else:
speed = args.print_step / (time.time() - avg_batch_time)
speed_list.append(speed)
logging.info(
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
avg_batch_time = time.time()
step_idx += 1
if step_idx == 10:
if step_idx == step_num:
break
return seq_ids
logging.info("Static Predict: avg_speed: %.4f step/s" %
(np.mean(speed_list)))
return seq_ids, seq_scores
class TestTransformer(unittest.TestCase):
......@@ -344,12 +380,17 @@ class TestTransformer(unittest.TestCase):
def _test_predict(self):
args, batch_generator = self.prepare(mode='test')
static_res = predict_static(args, batch_generator)
dygraph_res = predict_dygraph(args, batch_generator)
static_seq_ids, static_scores = predict_static(args, batch_generator)
dygraph_seq_ids, dygraph_scores = predict_dygraph(args, batch_generator)
self.assertTrue(
np.allclose(static_seq_ids, static_seq_ids),
msg="static_seq_ids: {} \n dygraph_seq_ids: {}".format(
static_seq_ids, dygraph_seq_ids))
self.assertTrue(
np.allclose(static_res, dygraph_res),
msg="static_res: {} \n dygraph_res: {}".format(static_res,
dygraph_res))
np.allclose(static_scores, dygraph_scores),
msg="static_scores: {} \n dygraph_scores: {}".format(
static_scores, dygraph_scores))
def test_check_result(self):
self._test_train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册