未验证 提交 316ea549 作者: liym27 提交者: GitHub

Revert to usage of 'fill_constant' in test_transformer. test=develop (#23529)

上级 ca7bd2be
...@@ -29,7 +29,7 @@ trainer_count = 1 ...@@ -29,7 +29,7 @@ trainer_count = 1
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
) )
SEED = 10 SEED = 10
step_num = 10 STEP_NUM = 10
def train_static(args, batch_generator): def train_static(args, batch_generator):
...@@ -109,7 +109,7 @@ def train_static(args, batch_generator): ...@@ -109,7 +109,7 @@ def train_static(args, batch_generator):
else: else:
logging.info( logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" % "normalized loss: %f, ppl: %f, speed: %.2f steps/s" %
(step_idx, pass_id, batch_id, total_avg_cost, (step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer, total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]), np.exp([min(total_avg_cost, 100)]),
...@@ -118,7 +118,7 @@ def train_static(args, batch_generator): ...@@ -118,7 +118,7 @@ def train_static(args, batch_generator):
batch_id += 1 batch_id += 1
step_idx += 1 step_idx += 1
total_batch_num = total_batch_num + 1 total_batch_num = total_batch_num + 1
if step_idx == step_num: if step_idx == STEP_NUM:
if args.save_dygraph_model_path: if args.save_dygraph_model_path:
model_path = os.path.join(args.save_static_model_path, model_path = os.path.join(args.save_static_model_path,
"transformer") "transformer")
...@@ -193,7 +193,8 @@ def train_dygraph(args, batch_generator): ...@@ -193,7 +193,8 @@ def train_dygraph(args, batch_generator):
else: else:
logging.info( logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" % "normalized loss: %f, ppl: %f, speed: %.2f steps/s"
%
(step_idx, pass_id, batch_id, total_avg_cost, (step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer, total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]), np.exp([min(total_avg_cost, 100)]),
...@@ -202,7 +203,7 @@ def train_dygraph(args, batch_generator): ...@@ -202,7 +203,7 @@ def train_dygraph(args, batch_generator):
avg_batch_time = time.time() avg_batch_time = time.time()
batch_id += 1 batch_id += 1
step_idx += 1 step_idx += 1
if step_idx == step_num: if step_idx == STEP_NUM:
if args.save_dygraph_model_path: if args.save_dygraph_model_path:
model_dir = os.path.join(args.save_dygraph_model_path) model_dir = os.path.join(args.save_dygraph_model_path)
if not os.path.exists(model_dir): if not os.path.exists(model_dir):
...@@ -277,14 +278,14 @@ def predict_dygraph(args, batch_generator): ...@@ -277,14 +278,14 @@ def predict_dygraph(args, batch_generator):
speed = args.print_step / (time.time() - avg_batch_time) speed = args.print_step / (time.time() - avg_batch_time)
speed_list.append(speed) speed_list.append(speed)
logging.info( logging.info(
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s" "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)) % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
avg_batch_time = time.time() avg_batch_time = time.time()
step_idx += 1 step_idx += 1
if step_idx == step_num: if step_idx == STEP_NUM:
break break
logging.info("Dygraph Predict: avg_speed: %.4f step/s" % logging.info("Dygraph Predict: avg_speed: %.4f steps/s" %
(np.mean(speed_list))) (np.mean(speed_list)))
return seq_ids, seq_scores return seq_ids, seq_scores
...@@ -353,14 +354,14 @@ def predict_static(args, batch_generator): ...@@ -353,14 +354,14 @@ def predict_static(args, batch_generator):
speed = args.print_step / (time.time() - avg_batch_time) speed = args.print_step / (time.time() - avg_batch_time)
speed_list.append(speed) speed_list.append(speed)
logging.info( logging.info(
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s" "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)) % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
avg_batch_time = time.time() avg_batch_time = time.time()
step_idx += 1 step_idx += 1
if step_idx == step_num: if step_idx == STEP_NUM:
break break
logging.info("Static Predict: avg_speed: %.4f step/s" % logging.info("Static Predict: avg_speed: %.4f steps/s" %
(np.mean(speed_list))) (np.mean(speed_list)))
return seq_ids, seq_scores return seq_ids, seq_scores
......
...@@ -608,8 +608,8 @@ class Transformer(Layer): ...@@ -608,8 +608,8 @@ class Transformer(Layer):
} for i in range(self.n_layer)] } for i in range(self.n_layer)]
for i in range(max_len): for i in range(max_len):
trg_pos = layers.zeros_like( trg_pos = layers.fill_constant(
trg_word) + i # TODO: modified for dygraph2static shape=trg_word.shape, dtype="int64", value=i)
caches = map_structure(merge_batch_beams, caches = map_structure(merge_batch_beams,
caches) # TODO: modified for dygraph2static caches) # TODO: modified for dygraph2static
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册