diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py index 954a40779d0ffcdbbf02b722ea0609bcda555f67..588d2b1f207308fac25eec9c758cbbba9d00b87e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py @@ -29,7 +29,7 @@ trainer_count = 1 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( ) SEED = 10 -step_num = 10 +STEP_NUM = 10 def train_static(args, batch_generator): @@ -109,7 +109,7 @@ def train_static(args, batch_generator): else: logging.info( "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " - "normalized loss: %f, ppl: %f, speed: %.2f step/s" % + "normalized loss: %f, ppl: %f, speed: %.2f steps/s" % (step_idx, pass_id, batch_id, total_avg_cost, total_avg_cost - loss_normalizer, np.exp([min(total_avg_cost, 100)]), @@ -118,7 +118,7 @@ def train_static(args, batch_generator): batch_id += 1 step_idx += 1 total_batch_num = total_batch_num + 1 - if step_idx == step_num: + if step_idx == STEP_NUM: if args.save_dygraph_model_path: model_path = os.path.join(args.save_static_model_path, "transformer") @@ -193,7 +193,8 @@ def train_dygraph(args, batch_generator): else: logging.info( "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " - "normalized loss: %f, ppl: %f, speed: %.2f step/s" % + "normalized loss: %f, ppl: %f, speed: %.2f steps/s" + % (step_idx, pass_id, batch_id, total_avg_cost, total_avg_cost - loss_normalizer, np.exp([min(total_avg_cost, 100)]), @@ -202,7 +203,7 @@ def train_dygraph(args, batch_generator): avg_batch_time = time.time() batch_id += 1 step_idx += 1 - if step_idx == step_num: + if step_idx == STEP_NUM: if args.save_dygraph_model_path: model_dir = os.path.join(args.save_dygraph_model_path) if not os.path.exists(model_dir): @@ -277,14 +278,14 @@ def predict_dygraph(args, batch_generator): speed = args.print_step / (time.time() - avg_batch_time) speed_list.append(speed) logging.info( - "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s" + "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s" % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)) avg_batch_time = time.time() step_idx += 1 - if step_idx == step_num: + if step_idx == STEP_NUM: break - logging.info("Dygraph Predict: avg_speed: %.4f step/s" % + logging.info("Dygraph Predict: avg_speed: %.4f steps/s" % (np.mean(speed_list))) return seq_ids, seq_scores @@ -353,14 +354,14 @@ def predict_static(args, batch_generator): speed = args.print_step / (time.time() - avg_batch_time) speed_list.append(speed) logging.info( - "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s" + "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s" % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)) avg_batch_time = time.time() step_idx += 1 - if step_idx == step_num: + if step_idx == STEP_NUM: break - logging.info("Static Predict: avg_speed: %.4f step/s" % + logging.info("Static Predict: avg_speed: %.4f steps/s" % (np.mean(speed_list))) return seq_ids, seq_scores diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py index 669106b45b4f93e36617050e07f86cf01390fabe..27b24e120d74898cfd1d56855e7983cb1bc29b08 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py @@ -608,8 +608,8 @@ class Transformer(Layer): } for i in range(self.n_layer)] for i in range(max_len): - trg_pos = layers.zeros_like( - trg_word) + i # TODO: modified for dygraph2static + trg_pos = layers.fill_constant( + shape=trg_word.shape, dtype="int64", value=i) caches = map_structure(merge_batch_beams, caches) # TODO: modified for dygraph2static logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,