From 660489acac9cb30a59b201b9814045f33ec87c6c Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Tue, 7 Apr 2020 11:18:06 +0800 Subject: [PATCH] Add log and check predicted scores. test=develop (#23506) --- .../dygraph_to_static/test_transformer.py | 69 +++++++++++++++---- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py index 7d4ccc45ae2..954a40779d0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py @@ -29,6 +29,7 @@ trainer_count = 1 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( ) SEED = 10 +step_num = 10 def train_static(args, batch_generator): @@ -117,7 +118,7 @@ def train_static(args, batch_generator): batch_id += 1 step_idx += 1 total_batch_num = total_batch_num + 1 - if step_idx == 10: + if step_idx == step_num: if args.save_dygraph_model_path: model_path = os.path.join(args.save_static_model_path, "transformer") @@ -201,7 +202,7 @@ def train_dygraph(args, batch_generator): avg_batch_time = time.time() batch_id += 1 step_idx += 1 - if step_idx == 10: + if step_idx == step_num: if args.save_dygraph_model_path: model_dir = os.path.join(args.save_dygraph_model_path) if not os.path.exists(model_dir): @@ -250,10 +251,11 @@ def predict_dygraph(args, batch_generator): transformer.eval() step_idx = 0 + speed_list = [] for input_data in test_loader(): (src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias) = input_data - finished_seq, finished_scores = transformer.beam_search( + seq_ids, seq_scores = transformer.beam_search( src_word, src_pos, src_slf_attn_bias, @@ -263,12 +265,28 @@ def predict_dygraph(args, batch_generator): eos_id=args.eos_idx, beam_size=args.beam_size, max_len=args.max_out_len) - finished_seq = finished_seq.numpy() - finished_scores = finished_scores.numpy() + seq_ids = seq_ids.numpy() + seq_scores = seq_scores.numpy() + if step_idx % args.print_step == 0: + if step_idx == 0: + logging.info( + "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f" + % (step_idx, seq_ids[0][0][0], seq_scores[0][0])) + avg_batch_time = time.time() + else: + speed = args.print_step / (time.time() - avg_batch_time) + speed_list.append(speed) + logging.info( + "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s" + % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)) + avg_batch_time = time.time() + step_idx += 1 - if step_idx == 10: + if step_idx == step_num: break - return finished_seq + logging.info("Dygraph Predict: avg_speed: %.4f step/s" % + (np.mean(speed_list))) + return seq_ids, seq_scores def predict_static(args, batch_generator): @@ -318,16 +336,34 @@ def predict_static(args, batch_generator): loader.set_batch_generator(batch_generator, places=place) step_idx = 0 + speed_list = [] for feed_dict in loader: seq_ids, seq_scores = exe.run( test_prog, feed=feed_dict, fetch_list=[out_ids.name, out_scores.name], return_numpy=True) + if step_idx % args.print_step == 0: + if step_idx == 0: + logging.info( + "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f," + % (step_idx, seq_ids[0][0][0], seq_scores[0][0])) + avg_batch_time = time.time() + else: + speed = args.print_step / (time.time() - avg_batch_time) + speed_list.append(speed) + logging.info( + "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s" + % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)) + avg_batch_time = time.time() + step_idx += 1 - if step_idx == 10: + if step_idx == step_num: break - return seq_ids + logging.info("Static Predict: avg_speed: %.4f step/s" % + (np.mean(speed_list))) + + return seq_ids, seq_scores class TestTransformer(unittest.TestCase): @@ -344,12 +380,17 @@ class TestTransformer(unittest.TestCase): def _test_predict(self): args, batch_generator = self.prepare(mode='test') - static_res = predict_static(args, batch_generator) - dygraph_res = predict_dygraph(args, batch_generator) + static_seq_ids, static_scores = predict_static(args, batch_generator) + dygraph_seq_ids, dygraph_scores = predict_dygraph(args, batch_generator) + + self.assertTrue( + np.allclose(static_seq_ids, static_seq_ids), + msg="static_seq_ids: {} \n dygraph_seq_ids: {}".format( + static_seq_ids, dygraph_seq_ids)) self.assertTrue( - np.allclose(static_res, dygraph_res), - msg="static_res: {} \n dygraph_res: {}".format(static_res, - dygraph_res)) + np.allclose(static_scores, dygraph_scores), + msg="static_scores: {} \n dygraph_scores: {}".format( + static_scores, dygraph_scores)) def test_check_result(self): self._test_train() -- GitLab