From d810d4859e30e40c5eb51693b39540202f85b0fc Mon Sep 17 00:00:00 2001 From: topduke <784990967@qq.com> Date: Wed, 26 Oct 2022 14:19:20 +0800 Subject: [PATCH] Update rec_nrtr_head.py (#8108) fix bug when export nrtr model with paddlepaddle develop version --- ppocr/modeling/heads/rec_nrtr_head.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py index bf9ef561..eb279400 100644 --- a/ppocr/modeling/heads/rec_nrtr_head.py +++ b/ppocr/modeling/heads/rec_nrtr_head.py @@ -162,7 +162,7 @@ class Transformer(nn.Layer): memory = src dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32) - for len_dec_seq in range(1, self.max_len): + for len_dec_seq in range(1, paddle.to_tensor(self.max_len)): dec_seq_embed = self.embedding(dec_seq) dec_seq_embed = self.positional_encoding(dec_seq_embed) tgt_mask = self.generate_square_subsequent_mask( @@ -304,7 +304,7 @@ class Transformer(nn.Layer): inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( active_inst_idx_list) # Decode - for len_dec_seq in range(1, self.max_len): + for len_dec_seq in range(1, paddle.to_tensor(self.max_len)): src_enc_copy = src_enc.clone() active_inst_idx_list = beam_decode_step( inst_dec_beams, len_dec_seq, src_enc_copy, -- GitLab