diff --git a/ppocr/modeling/heads/rec_visionlan_head.py b/ppocr/modeling/heads/rec_visionlan_head.py index 6ec92976fbf3e9cad2da3fa80ee8d6a614b1f8b5..86054d9bbb12613e3119b4c0d72f4670344d773a 100644 --- a/ppocr/modeling/heads/rec_visionlan_head.py +++ b/ppocr/modeling/heads/rec_visionlan_head.py @@ -26,7 +26,6 @@ import paddle.nn as nn import paddle.nn.functional as F from paddle.nn.initializer import Normal, XavierNormal import numpy as np -from ppocr.modeling.backbones.rec_resnet_45 import ResNet45 class PositionalEncoding(nn.Layer): @@ -237,7 +236,7 @@ class PP_layer(nn.Layer): # enc_output: b,256,512 reading_order = paddle.arange(self.character_len, dtype='int64') reading_order = reading_order.unsqueeze(0).expand( - [enc_output.shape[0], -1]) # (S,) -> (B, S) + [enc_output.shape[0], self.character_len]) # (S,) -> (B, S) reading_order = self.f0_embedding(reading_order) # b,25,512 # calculate attention @@ -431,32 +430,7 @@ class MLM_VRM(nn.Layer): use_mlm=False) text_pre = paddle.transpose( text_pre, perm=[1, 0, 2]) # (26, b, 37)) - lenText = nT - nsteps = nT - out_res = paddle.zeros( - shape=[lenText, b, self.nclass], dtype=x.dtype) # (25, b, 37) - out_length = paddle.zeros(shape=[b], dtype=x.dtype) - now_step = 0 - for _ in range(nsteps): - if 0 in out_length and now_step < nsteps: - tmp_result = text_pre[now_step, :, :] - out_res[now_step] = tmp_result - tmp_result = tmp_result.topk(1)[1].squeeze(axis=1) - for j in range(b): - if out_length[j] == 0 and tmp_result[j] == 0: - out_length[j] = now_step + 1 - now_step += 1 - for j in range(0, b): - if int(out_length[j]) == 0: - out_length[j] = nsteps - start = 0 - output = paddle.zeros( - shape=[int(out_length.sum()), self.nclass], dtype=x.dtype) - for i in range(0, b): - cur_length = int(out_length[i]) - output[start:start + cur_length] = out_res[0:cur_length, i, :] - start += cur_length - return output, out_length + return text_pre, x class VLHead(nn.Layer): @@ -489,6 +463,6 @@ class VLHead(nn.Layer): feat, label_pos, self.training_step, train_mode=True) return text_pre, test_rem, text_mas, mask_map else: - output, out_length = self.MLM_VRM( + text_pre, x = self.MLM_VRM( feat, targets, self.training_step, train_mode=False) - return output, out_length + return text_pre, x diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index fa287c0254a365a13ab6424a7125054ebb7344a4..75f4754bf2ed4ecd4dc0614d93aa9ab375889800 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -675,6 +675,8 @@ class VLLabelDecode(BaseRecLabelDecode): def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): super(VLLabelDecode, self).__init__(character_dict_path, use_space_char) + self.max_text_length = kwargs.get('max_text_length', 25) + self.nclass = len(self.character) + 1 def decode(self, text_index, text_prob=None, is_remove_duplicate=False): """ convert text-index into text-label. """ @@ -706,7 +708,40 @@ class VLLabelDecode(BaseRecLabelDecode): def __call__(self, preds, label=None, length=None, *args, **kwargs): if len(preds) == 2: # eval mode - net_out, length = preds + text_pre, x = preds + b = text_pre.shape[1] + lenText = self.max_text_length + nsteps = self.max_text_length + + if not isinstance(text_pre, paddle.Tensor): + text_pre = paddle.to_tensor(text_pre, dtype='float32') + + out_res = paddle.zeros( + shape=[lenText, b, self.nclass], dtype=x.dtype) + out_length = paddle.zeros(shape=[b], dtype=x.dtype) + now_step = 0 + for _ in range(nsteps): + if 0 in out_length and now_step < nsteps: + tmp_result = text_pre[now_step, :, :] + out_res[now_step] = tmp_result + tmp_result = tmp_result.topk(1)[1].squeeze(axis=1) + for j in range(b): + if out_length[j] == 0 and tmp_result[j] == 0: + out_length[j] = now_step + 1 + now_step += 1 + for j in range(0, b): + if int(out_length[j]) == 0: + out_length[j] = nsteps + start = 0 + output = paddle.zeros( + shape=[int(out_length.sum()), self.nclass], dtype=x.dtype) + for i in range(0, b): + cur_length = int(out_length[i]) + output[start:start + cur_length] = out_res[0:cur_length, i, :] + start += cur_length + net_out = output + length = out_length + else: # train mode net_out = preds[0] length = length @@ -714,8 +749,6 @@ class VLLabelDecode(BaseRecLabelDecode): text = [] if not isinstance(net_out, paddle.Tensor): net_out = paddle.to_tensor(net_out, dtype='float32') - # import pdb - # pdb.set_trace() net_out = F.softmax(net_out, axis=1) for i in range(0, length.shape[0]): preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(