diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py index bf9ef56145e6edfb15bd30235b4a62588396ba96..eb279400203b9ef173793bc0d90e5ab99701cb3a 100644 --- a/ppocr/modeling/heads/rec_nrtr_head.py +++ b/ppocr/modeling/heads/rec_nrtr_head.py @@ -162,7 +162,7 @@ class Transformer(nn.Layer): memory = src dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32) - for len_dec_seq in range(1, self.max_len): + for len_dec_seq in range(1, paddle.to_tensor(self.max_len)): dec_seq_embed = self.embedding(dec_seq) dec_seq_embed = self.positional_encoding(dec_seq_embed) tgt_mask = self.generate_square_subsequent_mask( @@ -304,7 +304,7 @@ class Transformer(nn.Layer): inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( active_inst_idx_list) # Decode - for len_dec_seq in range(1, self.max_len): + for len_dec_seq in range(1, paddle.to_tensor(self.max_len)): src_enc_copy = src_enc.clone() active_inst_idx_list = beam_decode_step( inst_dec_beams, len_dec_seq, src_enc_copy,