Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
f96b873a
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f96b873a
编写于
8月 17, 2020
作者:
L
licx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify infer tools for sast
上级
c352e176
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
51 addition
and
9 deletion
+51
-9
configs/det/det_sast_icdar15_reader.yml
configs/det/det_sast_icdar15_reader.yml
+2
-4
configs/det/det_sast_totaltext_reader.yml
configs/det/det_sast_totaltext_reader.yml
+1
-1
ppocr/modeling/heads/det_sast_head.py
ppocr/modeling/heads/det_sast_head.py
+2
-2
tools/infer/predict_det.py
tools/infer/predict_det.py
+16
-0
tools/infer/utility.py
tools/infer/utility.py
+7
-0
tools/infer_det.py
tools/infer_det.py
+23
-2
未找到文件。
configs/det/det_sast_icdar15_reader.yml
浏览文件 @
f96b873a
...
@@ -20,7 +20,5 @@ EvalReader:
...
@@ -20,7 +20,5 @@ EvalReader:
TestReader
:
TestReader
:
reader_function
:
ppocr.data.det.dataset_traversal,EvalTestReader
reader_function
:
ppocr.data.det.dataset_traversal,EvalTestReader
process_function
:
ppocr.data.det.sast_process,SASTProcessTest
process_function
:
ppocr.data.det.sast_process,SASTProcessTest
infer_img
:
infer_img
:
./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg
img_set_dir
:
./train_data/icdar2015/text_localization/
max_side_len
:
1536
label_file_path
:
./train_data/icdar2015/text_localization/test_icdar2015_label.txt
do_eval
:
True
configs/det/det_sast_totaltext_reader.yml
浏览文件 @
f96b873a
...
@@ -20,5 +20,5 @@ EvalReader:
...
@@ -20,5 +20,5 @@ EvalReader:
TestReader
:
TestReader
:
reader_function
:
ppocr.data.det.dataset_traversal,EvalTestReader
reader_function
:
ppocr.data.det.dataset_traversal,EvalTestReader
process_function
:
ppocr.data.det.sast_process,SASTProcessTest
process_function
:
ppocr.data.det.sast_process,SASTProcessTest
infer_img
:
infer_img
:
./train_data/afs/total_text/Images/Test/img623.jpg
max_side_len
:
768
max_side_len
:
768
ppocr/modeling/heads/det_sast_head.py
浏览文件 @
f96b873a
...
@@ -49,7 +49,7 @@ class SASTHead(object):
...
@@ -49,7 +49,7 @@ class SASTHead(object):
for
i
in
range
(
4
):
for
i
in
range
(
4
):
if
i
==
0
:
if
i
==
0
:
g
[
i
]
=
deconv_bn_layer
(
input
=
h
[
i
],
num_filters
=
num_outputs
[
i
+
1
],
act
=
None
,
name
=
'fpn_up_g0'
)
g
[
i
]
=
deconv_bn_layer
(
input
=
h
[
i
],
num_filters
=
num_outputs
[
i
+
1
],
act
=
None
,
name
=
'fpn_up_g0'
)
print
(
"g[{}] shape: {}"
.
format
(
i
,
g
[
i
].
shape
))
#
print("g[{}] shape: {}".format(i, g[i].shape))
else
:
else
:
g
[
i
]
=
fluid
.
layers
.
elementwise_add
(
x
=
g
[
i
-
1
],
y
=
h
[
i
])
g
[
i
]
=
fluid
.
layers
.
elementwise_add
(
x
=
g
[
i
-
1
],
y
=
h
[
i
])
g
[
i
]
=
fluid
.
layers
.
relu
(
g
[
i
])
g
[
i
]
=
fluid
.
layers
.
relu
(
g
[
i
])
...
@@ -58,7 +58,7 @@ class SASTHead(object):
...
@@ -58,7 +58,7 @@ class SASTHead(object):
g
[
i
]
=
conv_bn_layer
(
input
=
g
[
i
],
num_filters
=
num_outputs
[
i
],
g
[
i
]
=
conv_bn_layer
(
input
=
g
[
i
],
num_filters
=
num_outputs
[
i
],
filter_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'fpn_up_g%d_1'
%
i
)
filter_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'fpn_up_g%d_1'
%
i
)
g
[
i
]
=
deconv_bn_layer
(
input
=
g
[
i
],
num_filters
=
num_outputs
[
i
+
1
],
act
=
None
,
name
=
'fpn_up_g%d_2'
%
i
)
g
[
i
]
=
deconv_bn_layer
(
input
=
g
[
i
],
num_filters
=
num_outputs
[
i
+
1
],
act
=
None
,
name
=
'fpn_up_g%d_2'
%
i
)
print
(
"g[{}] shape: {}"
.
format
(
i
,
g
[
i
].
shape
))
#
print("g[{}] shape: {}".format(i, g[i].shape))
g
[
4
]
=
fluid
.
layers
.
elementwise_add
(
x
=
g
[
3
],
y
=
h
[
4
])
g
[
4
]
=
fluid
.
layers
.
elementwise_add
(
x
=
g
[
3
],
y
=
h
[
4
])
g
[
4
]
=
fluid
.
layers
.
relu
(
g
[
4
])
g
[
4
]
=
fluid
.
layers
.
relu
(
g
[
4
])
...
...
tools/infer/predict_det.py
浏览文件 @
f96b873a
...
@@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger
...
@@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger
logger
=
initial_logger
()
logger
=
initial_logger
()
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
import
cv2
import
cv2
from
ppocr.data.det.sast_process
import
SASTProcessTest
from
ppocr.data.det.east_process
import
EASTProcessTest
from
ppocr.data.det.east_process
import
EASTProcessTest
from
ppocr.data.det.db_process
import
DBProcessTest
from
ppocr.data.det.db_process
import
DBProcessTest
from
ppocr.postprocess.db_postprocess
import
DBPostProcess
from
ppocr.postprocess.db_postprocess
import
DBPostProcess
from
ppocr.postprocess.east_postprocess
import
EASTPostPocess
from
ppocr.postprocess.east_postprocess
import
EASTPostPocess
from
ppocr.postprocess.sast_postprocess
import
SASTPostProcess
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
import
math
import
math
...
@@ -52,6 +54,14 @@ class TextDetector(object):
...
@@ -52,6 +54,14 @@ class TextDetector(object):
postprocess_params
[
"cover_thresh"
]
=
args
.
det_east_cover_thresh
postprocess_params
[
"cover_thresh"
]
=
args
.
det_east_cover_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_east_nms_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_east_nms_thresh
self
.
postprocess_op
=
EASTPostPocess
(
postprocess_params
)
self
.
postprocess_op
=
EASTPostPocess
(
postprocess_params
)
elif
self
.
det_algorithm
==
"SAST"
:
self
.
preprocess_op
=
SASTProcessTest
(
preprocess_params
)
postprocess_params
[
"score_thresh"
]
=
args
.
det_sast_score_thresh
postprocess_params
[
"nms_thresh"
]
=
args
.
det_sast_nms_thresh
postprocess_params
[
"sample_pts_num"
]
=
args
.
det_sast_sample_pts_num
postprocess_params
[
"expand_scale"
]
=
args
.
det_sast_expand_scale
postprocess_params
[
"shrink_ratio_of_width"
]
=
args
.
det_sast_shrink_ratio_of_width
self
.
postprocess_op
=
SASTPostProcess
(
postprocess_params
)
else
:
else
:
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
sys
.
exit
(
0
)
sys
.
exit
(
0
)
...
@@ -120,8 +130,14 @@ class TextDetector(object):
...
@@ -120,8 +130,14 @@ class TextDetector(object):
if
self
.
det_algorithm
==
"EAST"
:
if
self
.
det_algorithm
==
"EAST"
:
outs_dict
[
'f_geo'
]
=
outputs
[
0
]
outs_dict
[
'f_geo'
]
=
outputs
[
0
]
outs_dict
[
'f_score'
]
=
outputs
[
1
]
outs_dict
[
'f_score'
]
=
outputs
[
1
]
elif
self
.
det_algorithm
==
'SAST'
:
outs_dict
[
'f_border'
]
=
outputs
[
0
]
outs_dict
[
'f_score'
]
=
outputs
[
1
]
outs_dict
[
'f_tco'
]
=
outputs
[
2
]
outs_dict
[
'f_tvo'
]
=
outputs
[
3
]
else
:
else
:
outs_dict
[
'maps'
]
=
outputs
[
0
]
outs_dict
[
'maps'
]
=
outputs
[
0
]
dt_boxes_list
=
self
.
postprocess_op
(
outs_dict
,
[
ratio_list
])
dt_boxes_list
=
self
.
postprocess_op
(
outs_dict
,
[
ratio_list
])
dt_boxes
=
dt_boxes_list
[
0
]
dt_boxes
=
dt_boxes_list
[
0
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
...
...
tools/infer/utility.py
浏览文件 @
f96b873a
...
@@ -53,6 +53,13 @@ def parse_args():
...
@@ -53,6 +53,13 @@ def parse_args():
parser
.
add_argument
(
"--det_east_cover_thresh"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--det_east_cover_thresh"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--det_east_nms_thresh"
,
type
=
float
,
default
=
0.2
)
parser
.
add_argument
(
"--det_east_nms_thresh"
,
type
=
float
,
default
=
0.2
)
#SAST parmas
parser
.
add_argument
(
"--det_sast_score_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_sast_nms_thresh"
,
type
=
float
,
default
=
0.2
)
parser
.
add_argument
(
"--det_sast_sample_pts_num"
,
type
=
float
,
default
=
2
)
parser
.
add_argument
(
"--det_sast_expand_scale"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--det_sast_shrink_ratio_of_width"
,
type
=
float
,
default
=
0.3
)
#params for text recognizer
#params for text recognizer
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
...
...
tools/infer_det.py
浏览文件 @
f96b873a
...
@@ -66,6 +66,25 @@ def draw_det_res(dt_boxes, config, img, img_name):
...
@@ -66,6 +66,25 @@ def draw_det_res(dt_boxes, config, img, img_name):
cv2
.
imwrite
(
save_path
,
src_im
)
cv2
.
imwrite
(
save_path
,
src_im
)
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
def
gen_im_detection
(
src_im
,
detections
):
"""
Generate image with detection results.
"""
im_detection
=
src_im
.
copy
()
h
,
w
,
_
=
im_detection
.
shape
thickness
=
int
(
max
((
h
+
w
)
/
2000
,
1
))
for
poly
in
detections
:
# Draw the first point
cv2
.
putText
(
im_detection
,
'0'
,
org
=
(
int
(
poly
[
0
,
0
]),
int
(
poly
[
0
,
1
])),
fontFace
=
cv2
.
FONT_HERSHEY_COMPLEX
,
fontScale
=
thickness
,
color
=
(
255
,
0
,
0
),
thickness
=
thickness
)
cv2
.
polylines
(
im_detection
,
np
.
array
(
poly
).
reshape
((
1
,
-
1
,
2
)).
astype
(
np
.
int32
),
isClosed
=
True
,
color
=
(
0
,
0
,
255
),
thickness
=
thickness
)
return
im_detection
def
main
():
def
main
():
config
=
program
.
load_config
(
FLAGS
.
config
)
config
=
program
.
load_config
(
FLAGS
.
config
)
...
@@ -134,8 +153,10 @@ def main():
...
@@ -134,8 +153,10 @@ def main():
dic
=
{
'f_score'
:
outs
[
0
],
'f_geo'
:
outs
[
1
]}
dic
=
{
'f_score'
:
outs
[
0
],
'f_geo'
:
outs
[
1
]}
elif
config
[
'Global'
][
'algorithm'
]
==
'DB'
:
elif
config
[
'Global'
][
'algorithm'
]
==
'DB'
:
dic
=
{
'maps'
:
outs
[
0
]}
dic
=
{
'maps'
:
outs
[
0
]}
elif
config
[
'Global'
][
'algorithm'
]
==
'SAST'
:
dic
=
{
'f_score'
:
outs
[
0
],
'f_border'
:
outs
[
1
],
'f_tvo'
:
outs
[
2
],
'f_tco'
:
outs
[
3
]}
else
:
else
:
raise
Exception
(
"only support algorithm: ['EAST', 'DB']"
)
raise
Exception
(
"only support algorithm: ['EAST', 'DB'
, 'SAST'
]"
)
dt_boxes_list
=
postprocess
(
dic
,
ratio_list
)
dt_boxes_list
=
postprocess
(
dic
,
ratio_list
)
for
ino
in
range
(
img_num
):
for
ino
in
range
(
img_num
):
dt_boxes
=
dt_boxes_list
[
ino
]
dt_boxes
=
dt_boxes_list
[
ino
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录