未验证 提交 316ea549 编写于 作者: liym27 提交者: GitHub

Revert to usage of 'fill_constant' in test_transformer. test=develop (#23529)

上级 ca7bd2be
......@@ -29,7 +29,7 @@ trainer_count = 1
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
SEED = 10
step_num = 10
STEP_NUM = 10
def train_static(args, batch_generator):
......@@ -109,7 +109,7 @@ def train_static(args, batch_generator):
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
"normalized loss: %f, ppl: %f, speed: %.2f steps/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
......@@ -118,7 +118,7 @@ def train_static(args, batch_generator):
batch_id += 1
step_idx += 1
total_batch_num = total_batch_num + 1
if step_idx == step_num:
if step_idx == STEP_NUM:
if args.save_dygraph_model_path:
model_path = os.path.join(args.save_static_model_path,
"transformer")
......@@ -193,7 +193,8 @@ def train_dygraph(args, batch_generator):
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
"normalized loss: %f, ppl: %f, speed: %.2f steps/s"
%
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
......@@ -202,7 +203,7 @@ def train_dygraph(args, batch_generator):
avg_batch_time = time.time()
batch_id += 1
step_idx += 1
if step_idx == step_num:
if step_idx == STEP_NUM:
if args.save_dygraph_model_path:
model_dir = os.path.join(args.save_dygraph_model_path)
if not os.path.exists(model_dir):
......@@ -277,14 +278,14 @@ def predict_dygraph(args, batch_generator):
speed = args.print_step / (time.time() - avg_batch_time)
speed_list.append(speed)
logging.info(
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
avg_batch_time = time.time()
step_idx += 1
if step_idx == step_num:
if step_idx == STEP_NUM:
break
logging.info("Dygraph Predict: avg_speed: %.4f step/s" %
logging.info("Dygraph Predict: avg_speed: %.4f steps/s" %
(np.mean(speed_list)))
return seq_ids, seq_scores
......@@ -353,14 +354,14 @@ def predict_static(args, batch_generator):
speed = args.print_step / (time.time() - avg_batch_time)
speed_list.append(speed)
logging.info(
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s"
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
avg_batch_time = time.time()
step_idx += 1
if step_idx == step_num:
if step_idx == STEP_NUM:
break
logging.info("Static Predict: avg_speed: %.4f step/s" %
logging.info("Static Predict: avg_speed: %.4f steps/s" %
(np.mean(speed_list)))
return seq_ids, seq_scores
......
......@@ -608,8 +608,8 @@ class Transformer(Layer):
} for i in range(self.n_layer)]
for i in range(max_len):
trg_pos = layers.zeros_like(
trg_word) + i # TODO: modified for dygraph2static
trg_pos = layers.fill_constant(
shape=trg_word.shape, dtype="int64", value=i)
caches = map_structure(merge_batch_beams,
caches) # TODO: modified for dygraph2static
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册