提交 0d3c2924 编写于 作者: A andyjpaddle

fix head out

上级 8656a1dd
...@@ -26,7 +26,6 @@ import paddle.nn as nn ...@@ -26,7 +26,6 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierNormal from paddle.nn.initializer import Normal, XavierNormal
import numpy as np import numpy as np
from ppocr.modeling.backbones.rec_resnet_45 import ResNet45
class PositionalEncoding(nn.Layer): class PositionalEncoding(nn.Layer):
...@@ -237,7 +236,7 @@ class PP_layer(nn.Layer): ...@@ -237,7 +236,7 @@ class PP_layer(nn.Layer):
# enc_output: b,256,512 # enc_output: b,256,512
reading_order = paddle.arange(self.character_len, dtype='int64') reading_order = paddle.arange(self.character_len, dtype='int64')
reading_order = reading_order.unsqueeze(0).expand( reading_order = reading_order.unsqueeze(0).expand(
[enc_output.shape[0], -1]) # (S,) -> (B, S) [enc_output.shape[0], self.character_len]) # (S,) -> (B, S)
reading_order = self.f0_embedding(reading_order) # b,25,512 reading_order = self.f0_embedding(reading_order) # b,25,512
# calculate attention # calculate attention
...@@ -431,32 +430,7 @@ class MLM_VRM(nn.Layer): ...@@ -431,32 +430,7 @@ class MLM_VRM(nn.Layer):
use_mlm=False) use_mlm=False)
text_pre = paddle.transpose( text_pre = paddle.transpose(
text_pre, perm=[1, 0, 2]) # (26, b, 37)) text_pre, perm=[1, 0, 2]) # (26, b, 37))
lenText = nT return text_pre, x
nsteps = nT
out_res = paddle.zeros(
shape=[lenText, b, self.nclass], dtype=x.dtype) # (25, b, 37)
out_length = paddle.zeros(shape=[b], dtype=x.dtype)
now_step = 0
for _ in range(nsteps):
if 0 in out_length and now_step < nsteps:
tmp_result = text_pre[now_step, :, :]
out_res[now_step] = tmp_result
tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
for j in range(b):
if out_length[j] == 0 and tmp_result[j] == 0:
out_length[j] = now_step + 1
now_step += 1
for j in range(0, b):
if int(out_length[j]) == 0:
out_length[j] = nsteps
start = 0
output = paddle.zeros(
shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
for i in range(0, b):
cur_length = int(out_length[i])
output[start:start + cur_length] = out_res[0:cur_length, i, :]
start += cur_length
return output, out_length
class VLHead(nn.Layer): class VLHead(nn.Layer):
...@@ -489,6 +463,6 @@ class VLHead(nn.Layer): ...@@ -489,6 +463,6 @@ class VLHead(nn.Layer):
feat, label_pos, self.training_step, train_mode=True) feat, label_pos, self.training_step, train_mode=True)
return text_pre, test_rem, text_mas, mask_map return text_pre, test_rem, text_mas, mask_map
else: else:
output, out_length = self.MLM_VRM( text_pre, x = self.MLM_VRM(
feat, targets, self.training_step, train_mode=False) feat, targets, self.training_step, train_mode=False)
return output, out_length return text_pre, x
...@@ -675,6 +675,8 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -675,6 +675,8 @@ class VLLabelDecode(BaseRecLabelDecode):
def __init__(self, character_dict_path=None, use_space_char=False, def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs): **kwargs):
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char) super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
self.nclass = len(self.character) + 1
def decode(self, text_index, text_prob=None, is_remove_duplicate=False): def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """ """ convert text-index into text-label. """
...@@ -706,7 +708,40 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -706,7 +708,40 @@ class VLLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, length=None, *args, **kwargs): def __call__(self, preds, label=None, length=None, *args, **kwargs):
if len(preds) == 2: # eval mode if len(preds) == 2: # eval mode
net_out, length = preds text_pre, x = preds
b = text_pre.shape[1]
lenText = self.max_text_length
nsteps = self.max_text_length
if not isinstance(text_pre, paddle.Tensor):
text_pre = paddle.to_tensor(text_pre, dtype='float32')
out_res = paddle.zeros(
shape=[lenText, b, self.nclass], dtype=x.dtype)
out_length = paddle.zeros(shape=[b], dtype=x.dtype)
now_step = 0
for _ in range(nsteps):
if 0 in out_length and now_step < nsteps:
tmp_result = text_pre[now_step, :, :]
out_res[now_step] = tmp_result
tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
for j in range(b):
if out_length[j] == 0 and tmp_result[j] == 0:
out_length[j] = now_step + 1
now_step += 1
for j in range(0, b):
if int(out_length[j]) == 0:
out_length[j] = nsteps
start = 0
output = paddle.zeros(
shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
for i in range(0, b):
cur_length = int(out_length[i])
output[start:start + cur_length] = out_res[0:cur_length, i, :]
start += cur_length
net_out = output
length = out_length
else: # train mode else: # train mode
net_out = preds[0] net_out = preds[0]
length = length length = length
...@@ -714,8 +749,6 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -714,8 +749,6 @@ class VLLabelDecode(BaseRecLabelDecode):
text = [] text = []
if not isinstance(net_out, paddle.Tensor): if not isinstance(net_out, paddle.Tensor):
net_out = paddle.to_tensor(net_out, dtype='float32') net_out = paddle.to_tensor(net_out, dtype='float32')
# import pdb
# pdb.set_trace()
net_out = F.softmax(net_out, axis=1) net_out = F.softmax(net_out, axis=1)
for i in range(0, length.shape[0]): for i in range(0, length.shape[0]):
preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum( preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册