Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
5fb3c419
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5fb3c419
编写于
9月 03, 2020
作者:
X
xiaoting
提交者:
GitHub
9月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #665 from tink2123/support_srn_inference
Support srn inference
上级
21af7660
5fc9fff9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
185 addition
and
27 deletion
+185
-27
ppocr/modeling/architectures/rec_model.py
ppocr/modeling/architectures/rec_model.py
+9
-7
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+159
-15
tools/infer/utility.py
tools/infer/utility.py
+2
-1
tools/infer_rec.py
tools/infer_rec.py
+1
-1
tools/program.py
tools/program.py
+14
-3
未找到文件。
ppocr/modeling/architectures/rec_model.py
浏览文件 @
5fb3c419
...
...
@@ -136,7 +136,7 @@ class RecModel(object):
else
:
labels
=
None
loader
=
None
if
self
.
char_type
==
"ch"
and
self
.
infer_img
:
if
self
.
char_type
==
"ch"
and
self
.
infer_img
and
self
.
loss_type
!=
"srn"
:
image_shape
[
-
1
]
=
-
1
if
self
.
tps
!=
None
:
logger
.
info
(
...
...
@@ -172,16 +172,13 @@ class RecModel(object):
self
.
max_text_length
],
dtype
=
"float32"
)
feed_list
=
[
image
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
labels
=
{
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
}
return
image
,
labels
,
loader
def
__call__
(
self
,
mode
):
...
...
@@ -218,8 +215,13 @@ class RecModel(object):
if
self
.
loss_type
==
"ctc"
:
predict
=
fluid
.
layers
.
softmax
(
predict
)
if
self
.
loss_type
==
"srn"
:
raise
Exception
(
"Warning! SRN does not support export model currently"
)
return
[
image
,
labels
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}
]
return
[
image
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}]
else
:
predict
=
predicts
[
'predict'
]
...
...
tools/infer/predict_rec.py
浏览文件 @
5fb3c419
...
...
@@ -40,6 +40,7 @@ class TextRecognizer(object):
self
.
character_type
=
args
.
rec_char_type
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_algorithm
=
args
.
rec_algorithm
self
.
text_len
=
args
.
max_text_length
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
...
...
@@ -47,12 +48,15 @@ class TextRecognizer(object):
"use_space_char"
:
args
.
use_space_char
,
"max_text_length"
:
args
.
max_text_length
}
if
self
.
rec_algorithm
!=
"RARE"
:
if
self
.
rec_algorithm
in
[
"CRNN"
,
"Rosetta"
,
"STAR-Net"
]
:
char_ops_params
[
'loss_type'
]
=
'ctc'
self
.
loss_type
=
'ctc'
el
se
:
el
if
self
.
rec_algorithm
==
"RARE"
:
char_ops_params
[
'loss_type'
]
=
'attention'
self
.
loss_type
=
'attention'
elif
self
.
rec_algorithm
==
"SRN"
:
char_ops_params
[
'loss_type'
]
=
'srn'
self
.
loss_type
=
'srn'
self
.
char_ops
=
CharacterOps
(
char_ops_params
)
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
...
...
@@ -75,6 +79,83 @@ class TextRecognizer(object):
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
def
resize_norm_img_srn
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
self
,
image_shape
,
num_heads
,
max_text_length
,
char_num
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
self
,
img
,
image_shape
,
num_heads
,
max_text_length
,
char_ops
=
None
):
norm_img
=
self
.
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
char_num
=
char_ops
.
get_char_num
()
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
self
.
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
,
char_num
)
gsrm_slf_attn_bias1
=
gsrm_slf_attn_bias1
.
astype
(
np
.
float32
)
gsrm_slf_attn_bias2
=
gsrm_slf_attn_bias2
.
astype
(
np
.
float32
)
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
...
...
@@ -84,7 +165,7 @@ class TextRecognizer(object):
# Sorting can speed up the recognition process
indices
=
np
.
argsort
(
np
.
array
(
width_list
))
#
rec_res = []
#rec_res = []
rec_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
rec_batch_num
predict_time
=
0
...
...
@@ -98,20 +179,62 @@ class TextRecognizer(object):
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
if
self
.
loss_type
!=
"srn"
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
,
self
.
char_ops
)
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
norm_img
[
1
])
gsrm_word_pos_list
.
append
(
norm_img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
norm_img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
norm_img
[
4
])
norm_img_batch
.
append
(
norm_img
[
0
])
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
,
axis
=
0
)
norm_img_batch
=
norm_img_batch
.
copy
()
starttime
=
time
.
time
()
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
zero_copy_run
()
else
:
if
self
.
loss_type
==
"srn"
:
starttime
=
time
.
time
()
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
)
starttime
=
time
.
time
()
norm_img_batch
=
fluid
.
core
.
PaddleTensor
(
norm_img_batch
)
self
.
predictor
.
run
([
norm_img_batch
])
encoder_word_pos_list
=
fluid
.
core
.
PaddleTensor
(
encoder_word_pos_list
)
gsrm_word_pos_list
=
fluid
.
core
.
PaddleTensor
(
gsrm_word_pos_list
)
gsrm_slf_attn_bias1_list
=
fluid
.
core
.
PaddleTensor
(
gsrm_slf_attn_bias1_list
)
gsrm_slf_attn_bias2_list
=
fluid
.
core
.
PaddleTensor
(
gsrm_slf_attn_bias2_list
)
inputs
=
[
norm_img_batch
,
encoder_word_pos_list
,
gsrm_slf_attn_bias1_list
,
gsrm_slf_attn_bias2_list
,
gsrm_word_pos_list
]
self
.
predictor
.
run
(
inputs
)
else
:
starttime
=
time
.
time
()
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
zero_copy_run
()
else
:
norm_img_batch
=
fluid
.
core
.
PaddleTensor
(
norm_img_batch
)
self
.
predictor
.
run
([
norm_img_batch
])
if
self
.
loss_type
==
"ctc"
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
...
...
@@ -136,6 +259,26 @@ class TextRecognizer(object):
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
elif
self
.
loss_type
==
'srn'
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
probs
=
self
.
output_tensors
[
1
].
copy_to_cpu
()
char_num
=
self
.
char_ops
.
get_char_num
()
preds
=
rec_idx_batch
.
reshape
(
-
1
)
elapse
=
time
.
time
()
-
starttime
predict_time
+=
elapse
total_preds
=
preds
.
copy
()
for
ino
in
range
(
int
(
len
(
rec_idx_batch
)
/
self
.
text_len
)):
preds
=
total_preds
[
ino
*
self
.
text_len
:(
ino
+
1
)
*
self
.
text_len
]
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
valid_ind
=
np
.
where
(
preds
!=
int
(
char_num
-
1
))[
0
]
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
preds
=
preds
[:
valid_ind
[
-
1
]
+
1
]
preds_text
=
self
.
char_ops
.
decode
(
preds
)
rec_res
[
indices
[
beg_img_no
+
ino
]]
=
[
preds_text
,
score
]
else
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
predict_batch
=
self
.
output_tensors
[
1
].
copy_to_cpu
()
...
...
@@ -170,6 +313,7 @@ def main(args):
continue
valid_image_file_list
.
append
(
image_file
)
img_list
.
append
(
img
)
try
:
rec_res
,
predict_time
=
text_recognizer
(
img_list
)
except
Exception
as
e
:
...
...
tools/infer/utility.py
浏览文件 @
5fb3c419
...
...
@@ -114,7 +114,8 @@ def create_predictor(args, mode):
predictor
=
create_paddle_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_tensor
=
predictor
.
get_input_tensor
(
input_names
[
0
])
for
name
in
input_names
:
input_tensor
=
predictor
.
get_input_tensor
(
name
)
output_names
=
predictor
.
get_output_names
()
output_tensors
=
[]
for
output_name
in
output_names
:
...
...
tools/infer_rec.py
浏览文件 @
5fb3c419
...
...
@@ -145,7 +145,7 @@ def main():
preds
=
preds
.
reshape
(
-
1
)
probs
=
np
.
array
(
predict
[
1
])
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
valid_ind
=
np
.
where
(
preds
!=
int
(
char_num
-
1
))[
0
]
valid_ind
=
np
.
where
(
preds
!=
int
(
char_num
-
1
))[
0
]
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
...
...
tools/program.py
浏览文件 @
5fb3c419
...
...
@@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
with
fluid
.
unique_name
.
guard
():
func_infor
=
config
[
'Architecture'
][
'function'
]
model
=
create_module
(
func_infor
)(
params
=
config
)
image
,
outputs
=
model
(
mode
=
'export'
)
algorithm
=
config
[
'Global'
][
'algorithm'
]
if
algorithm
==
"SRN"
:
image
,
others
,
outputs
=
model
(
mode
=
'export'
)
else
:
image
,
outputs
=
model
(
mode
=
'export'
)
fetches_var_name
=
sorted
([
name
for
name
in
outputs
.
keys
()])
fetches_var
=
[
outputs
[
name
]
for
name
in
fetches_var_name
]
feeded_var_names
=
[
image
.
name
]
if
algorithm
==
"SRN"
:
others_var_names
=
sorted
([
name
for
name
in
others
.
keys
()])
feeded_var_names
=
[
image
.
name
]
+
others_var_names
else
:
feeded_var_names
=
[
image
.
name
]
target_vars
=
fetches_var
return
feeded_var_names
,
target_vars
,
fetches_var_name
...
...
@@ -409,7 +418,9 @@ def preprocess():
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录