Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
0d3c2924
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0d3c2924
编写于
7月 28, 2022
作者:
A
andyjpaddle
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix head out
上级
8656a1dd
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
40 addition
and
33 deletion
+40
-33
ppocr/modeling/heads/rec_visionlan_head.py
ppocr/modeling/heads/rec_visionlan_head.py
+4
-30
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+36
-3
未找到文件。
ppocr/modeling/heads/rec_visionlan_head.py
浏览文件 @
0d3c2924
...
...
@@ -26,7 +26,6 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
Normal
,
XavierNormal
import
numpy
as
np
from
ppocr.modeling.backbones.rec_resnet_45
import
ResNet45
class
PositionalEncoding
(
nn
.
Layer
):
...
...
@@ -237,7 +236,7 @@ class PP_layer(nn.Layer):
# enc_output: b,256,512
reading_order
=
paddle
.
arange
(
self
.
character_len
,
dtype
=
'int64'
)
reading_order
=
reading_order
.
unsqueeze
(
0
).
expand
(
[
enc_output
.
shape
[
0
],
-
1
])
# (S,) -> (B, S)
[
enc_output
.
shape
[
0
],
self
.
character_len
])
# (S,) -> (B, S)
reading_order
=
self
.
f0_embedding
(
reading_order
)
# b,25,512
# calculate attention
...
...
@@ -431,32 +430,7 @@ class MLM_VRM(nn.Layer):
use_mlm
=
False
)
text_pre
=
paddle
.
transpose
(
text_pre
,
perm
=
[
1
,
0
,
2
])
# (26, b, 37))
lenText
=
nT
nsteps
=
nT
out_res
=
paddle
.
zeros
(
shape
=
[
lenText
,
b
,
self
.
nclass
],
dtype
=
x
.
dtype
)
# (25, b, 37)
out_length
=
paddle
.
zeros
(
shape
=
[
b
],
dtype
=
x
.
dtype
)
now_step
=
0
for
_
in
range
(
nsteps
):
if
0
in
out_length
and
now_step
<
nsteps
:
tmp_result
=
text_pre
[
now_step
,
:,
:]
out_res
[
now_step
]
=
tmp_result
tmp_result
=
tmp_result
.
topk
(
1
)[
1
].
squeeze
(
axis
=
1
)
for
j
in
range
(
b
):
if
out_length
[
j
]
==
0
and
tmp_result
[
j
]
==
0
:
out_length
[
j
]
=
now_step
+
1
now_step
+=
1
for
j
in
range
(
0
,
b
):
if
int
(
out_length
[
j
])
==
0
:
out_length
[
j
]
=
nsteps
start
=
0
output
=
paddle
.
zeros
(
shape
=
[
int
(
out_length
.
sum
()),
self
.
nclass
],
dtype
=
x
.
dtype
)
for
i
in
range
(
0
,
b
):
cur_length
=
int
(
out_length
[
i
])
output
[
start
:
start
+
cur_length
]
=
out_res
[
0
:
cur_length
,
i
,
:]
start
+=
cur_length
return
output
,
out_length
return
text_pre
,
x
class
VLHead
(
nn
.
Layer
):
...
...
@@ -489,6 +463,6 @@ class VLHead(nn.Layer):
feat
,
label_pos
,
self
.
training_step
,
train_mode
=
True
)
return
text_pre
,
test_rem
,
text_mas
,
mask_map
else
:
output
,
out_length
=
self
.
MLM_VRM
(
text_pre
,
x
=
self
.
MLM_VRM
(
feat
,
targets
,
self
.
training_step
,
train_mode
=
False
)
return
output
,
out_length
return
text_pre
,
x
ppocr/postprocess/rec_postprocess.py
浏览文件 @
0d3c2924
...
...
@@ -675,6 +675,8 @@ class VLLabelDecode(BaseRecLabelDecode):
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
VLLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
self
.
max_text_length
=
kwargs
.
get
(
'max_text_length'
,
25
)
self
.
nclass
=
len
(
self
.
character
)
+
1
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
...
...
@@ -706,7 +708,40 @@ class VLLabelDecode(BaseRecLabelDecode):
def
__call__
(
self
,
preds
,
label
=
None
,
length
=
None
,
*
args
,
**
kwargs
):
if
len
(
preds
)
==
2
:
# eval mode
net_out
,
length
=
preds
text_pre
,
x
=
preds
b
=
text_pre
.
shape
[
1
]
lenText
=
self
.
max_text_length
nsteps
=
self
.
max_text_length
if
not
isinstance
(
text_pre
,
paddle
.
Tensor
):
text_pre
=
paddle
.
to_tensor
(
text_pre
,
dtype
=
'float32'
)
out_res
=
paddle
.
zeros
(
shape
=
[
lenText
,
b
,
self
.
nclass
],
dtype
=
x
.
dtype
)
out_length
=
paddle
.
zeros
(
shape
=
[
b
],
dtype
=
x
.
dtype
)
now_step
=
0
for
_
in
range
(
nsteps
):
if
0
in
out_length
and
now_step
<
nsteps
:
tmp_result
=
text_pre
[
now_step
,
:,
:]
out_res
[
now_step
]
=
tmp_result
tmp_result
=
tmp_result
.
topk
(
1
)[
1
].
squeeze
(
axis
=
1
)
for
j
in
range
(
b
):
if
out_length
[
j
]
==
0
and
tmp_result
[
j
]
==
0
:
out_length
[
j
]
=
now_step
+
1
now_step
+=
1
for
j
in
range
(
0
,
b
):
if
int
(
out_length
[
j
])
==
0
:
out_length
[
j
]
=
nsteps
start
=
0
output
=
paddle
.
zeros
(
shape
=
[
int
(
out_length
.
sum
()),
self
.
nclass
],
dtype
=
x
.
dtype
)
for
i
in
range
(
0
,
b
):
cur_length
=
int
(
out_length
[
i
])
output
[
start
:
start
+
cur_length
]
=
out_res
[
0
:
cur_length
,
i
,
:]
start
+=
cur_length
net_out
=
output
length
=
out_length
else
:
# train mode
net_out
=
preds
[
0
]
length
=
length
...
...
@@ -714,8 +749,6 @@ class VLLabelDecode(BaseRecLabelDecode):
text
=
[]
if
not
isinstance
(
net_out
,
paddle
.
Tensor
):
net_out
=
paddle
.
to_tensor
(
net_out
,
dtype
=
'float32'
)
# import pdb
# pdb.set_trace()
net_out
=
F
.
softmax
(
net_out
,
axis
=
1
)
for
i
in
range
(
0
,
length
.
shape
[
0
]):
preds_idx
=
net_out
[
int
(
length
[:
i
].
sum
()):
int
(
length
[:
i
].
sum
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录