Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
1b190503
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1b190503
编写于
8月 25, 2020
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adaptation of Chinese and r34/18
上级
7b201a38
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
64 addition
and
25 deletion
+64
-25
ppocr/data/rec/dataset_traversal.py
ppocr/data/rec/dataset_traversal.py
+34
-14
ppocr/data/rec/img_tools.py
ppocr/data/rec/img_tools.py
+7
-5
ppocr/modeling/backbones/rec_resnet50_fpn.py
ppocr/modeling/backbones/rec_resnet50_fpn.py
+17
-0
ppocr/utils/character.py
ppocr/utils/character.py
+4
-4
tools/infer_rec.py
tools/infer_rec.py
+2
-2
未找到文件。
ppocr/data/rec/dataset_traversal.py
浏览文件 @
1b190503
...
...
@@ -214,6 +214,8 @@ class SimpleReader(object):
self
.
mode
=
params
[
'mode'
]
self
.
infer_img
=
params
[
'infer_img'
]
self
.
use_tps
=
False
if
"num_heads"
in
params
:
self
.
num_heads
=
params
[
'num_heads'
]
if
"tps"
in
params
:
self
.
use_tps
=
True
self
.
use_distort
=
False
...
...
@@ -251,6 +253,13 @@ class SimpleReader(object):
img
=
cv2
.
imread
(
single_img
)
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
self
.
loss_type
==
'srn'
:
norm_img
=
process_image_srn
(
img
=
img
,
image_shape
=
self
.
image_shape
,
num_heads
=
self
.
num_heads
,
max_text_length
=
self
.
max_text_length
)
else
:
norm_img
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
...
...
@@ -286,6 +295,17 @@ class SimpleReader(object):
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
label
=
substr
[
1
]
if
self
.
loss_type
==
"srn"
:
outs
=
process_image_srn
(
img
=
img
,
image_shape
=
self
.
image_shape
,
num_heads
=
self
.
num_heads
,
max_text_length
=
self
.
max_text_length
,
label
=
label
,
char_ops
=
self
.
char_ops
,
loss_type
=
self
.
loss_type
)
else
:
outs
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
...
...
ppocr/data/rec/img_tools.py
浏览文件 @
1b190503
...
...
@@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape):
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
max_text_length
,
char_num
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
...
...
@@ -418,7 +419,7 @@ def srn_other_inputs(image_shape,
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
((
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
((
max_text_length
,
1
)).
astype
(
'int64'
)
lbl_weight
=
np
.
array
([
37
]
*
max_text_length
).
reshape
((
-
1
,
1
)).
astype
(
'int64'
)
lbl_weight
=
np
.
array
([
int
(
char_num
-
1
)
]
*
max_text_length
).
reshape
((
-
1
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
...
...
@@ -441,17 +442,18 @@ def process_image_srn(img,
loss_type
=
None
):
norm_img
=
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
char_num
=
char_ops
.
get_char_num
()
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
,
char_num
)
if
label
is
not
None
:
char_num
=
char_ops
.
get_char_num
()
text
=
char_ops
.
encode
(
label
)
if
len
(
text
)
==
0
or
len
(
text
)
>
max_text_length
:
return
None
else
:
if
loss_type
==
"srn"
:
text_padded
=
[
37
]
*
max_text_length
text_padded
=
[
int
(
char_num
-
1
)
]
*
max_text_length
for
i
in
range
(
len
(
text
)):
text_padded
[
i
]
=
text
[
i
]
lbl_weight
[
i
]
=
[
1.0
]
...
...
ppocr/modeling/backbones/rec_resnet50_fpn.py
浏览文件 @
1b190503
...
...
@@ -81,6 +81,23 @@ class ResNet():
num_filters
=
num_filters
[
block
],
stride
=
stride_list
[
block
]
if
i
==
0
else
1
,
name
=
conv_name
)
F
.
append
(
conv
)
else
:
for
block
in
range
(
len
(
depth
)):
for
i
in
range
(
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
if
i
==
0
and
block
!=
0
:
stride
=
(
2
,
1
)
else
:
stride
=
(
1
,
1
)
conv
=
self
.
basic_block
(
input
=
conv
,
num_filters
=
num_filters
[
block
],
stride
=
stride
,
if_first
=
block
==
i
==
0
,
name
=
conv_name
)
F
.
append
(
conv
)
base
=
F
[
-
1
]
for
i
in
[
-
2
,
-
3
]:
...
...
ppocr/utils/character.py
浏览文件 @
1b190503
...
...
@@ -26,8 +26,6 @@ class CharacterOps(object):
self
.
character_type
=
config
[
'character_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
self
.
max_text_len
=
config
[
'max_text_length'
]
if
self
.
loss_type
==
"srn"
and
self
.
character_type
!=
"en"
:
raise
Exception
(
"SRN can only support in character_type == en"
)
if
self
.
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
...
...
@@ -160,13 +158,15 @@ def cal_predicts_accuracy_srn(char_ops,
acc_num
=
0
img_num
=
0
char_num
=
char_ops
.
get_char_num
()
total_len
=
preds
.
shape
[
0
]
img_num
=
int
(
total_len
/
max_text_len
)
for
i
in
range
(
img_num
):
cur_label
=
[]
cur_pred
=
[]
for
j
in
range
(
max_text_len
):
if
labels
[
j
+
i
*
max_text_len
]
!=
37
:
#0
if
labels
[
j
+
i
*
max_text_len
]
!=
int
(
char_num
-
1
)
:
#0
cur_label
.
append
(
labels
[
j
+
i
*
max_text_len
][
0
])
else
:
break
...
...
@@ -178,7 +178,7 @@ def cal_predicts_accuracy_srn(char_ops,
elif
j
==
len
(
cur_label
)
and
j
==
max_text_len
:
acc_num
+=
1
break
elif
j
==
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
==
37
:
elif
j
==
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
==
int
(
char_num
-
1
)
:
acc_num
+=
1
break
acc
=
acc_num
*
1.0
/
img_num
...
...
tools/infer_rec.py
浏览文件 @
1b190503
...
...
@@ -140,12 +140,12 @@ def main():
preds
=
preds
.
reshape
(
-
1
)
preds_text
=
char_ops
.
decode
(
preds
)
elif
loss_type
==
"srn"
:
c
ur_pred
=
[]
c
har_num
=
char_ops
.
get_char_num
()
preds
=
np
.
array
(
predict
[
0
])
preds
=
preds
.
reshape
(
-
1
)
probs
=
np
.
array
(
predict
[
1
])
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
valid_ind
=
np
.
where
(
preds
!=
37
)[
0
]
valid_ind
=
np
.
where
(
preds
!=
int
(
char_num
-
1
)
)[
0
]
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录