未验证 提交 d810d485 编写于 作者: T topduke 提交者: GitHub

Update rec_nrtr_head.py (#8108)

fix bug when export nrtr model with paddlepaddle develop version
上级 f2835168
......@@ -162,7 +162,7 @@ class Transformer(nn.Layer):
memory = src
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
for len_dec_seq in range(1, self.max_len):
for len_dec_seq in range(1, paddle.to_tensor(self.max_len)):
dec_seq_embed = self.embedding(dec_seq)
dec_seq_embed = self.positional_encoding(dec_seq_embed)
tgt_mask = self.generate_square_subsequent_mask(
......@@ -304,7 +304,7 @@ class Transformer(nn.Layer):
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
active_inst_idx_list)
# Decode
for len_dec_seq in range(1, self.max_len):
for len_dec_seq in range(1, paddle.to_tensor(self.max_len)):
src_enc_copy = src_enc.clone()
active_inst_idx_list = beam_decode_step(
inst_dec_beams, len_dec_seq, src_enc_copy,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册