Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
4402e629
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4402e629
编写于
11月 09, 2020
作者:
W
WenmuZhou
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
修正export_model里的bug,添加predict_det
上级
89e031f0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
18 deletion
+34
-18
tools/export_model.py
tools/export_model.py
+15
-9
tools/infer/predict_det.py
tools/infer/predict_det.py
+19
-9
未找到文件。
tools/export_model.py
浏览文件 @
4402e629
...
@@ -12,6 +12,13 @@
...
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
argparse
import
argparse
import
paddle
import
paddle
...
@@ -20,14 +27,11 @@ from paddle.jit import to_static
...
@@ -20,14 +27,11 @@ from paddle.jit import to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
from
tools.program
import
load_config
from
tools.program
import
merge_config
def
parse_args
():
def
parse_args
():
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
...
@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
# Please modify the 'shape' according to actual needs
# Please modify the 'shape' according to actual needs
@
to_static
(
input_spec
=
[
@
to_static
(
input_spec
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
32
,
None
],
dtype
=
'float32'
)
shape
=
[
None
,
3
,
640
,
640
],
dtype
=
'float32'
)
])
])
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
x
=
self
.
pre_model
(
inputs
)
x
=
self
.
pre_model
(
inputs
)
...
@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
...
@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
def
main
():
def
main
():
FLAGS
=
parse_args
()
FLAGS
=
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
config
[
'Global'
])
# build model
# build model
#for rec algorithm
#
for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
@@ -69,7 +72,10 @@ def main():
...
@@ -69,7 +72,10 @@ def main():
model
.
eval
()
model
.
eval
()
model
=
Model
(
model
)
model
=
Model
(
model
)
paddle
.
jit
.
save
(
model
,
FLAGS
.
output_path
)
save_path
=
'{}/{}'
.
format
(
FLAGS
.
output_path
,
config
[
'Architecture'
][
'model_type'
])
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_det.py
浏览文件 @
4402e629
...
@@ -22,7 +22,6 @@ import cv2
...
@@ -22,7 +22,6 @@ import cv2
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
sys
import
sys
import
paddle
import
paddle
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
...
@@ -39,7 +38,7 @@ class TextDetector(object):
...
@@ -39,7 +38,7 @@ class TextDetector(object):
postprocess_params
=
{}
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
if
self
.
det_algorithm
==
"DB"
:
pre_process_list
=
[{
pre_process_list
=
[{
'ResizeForTest'
:
{
'
Det
ResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
'limit_type'
:
args
.
det_limit_type
}
}
...
@@ -53,7 +52,7 @@ class TextDetector(object):
...
@@ -53,7 +52,7 @@ class TextDetector(object):
},
{
},
{
'ToCHWImage'
:
None
'ToCHWImage'
:
None
},
{
},
{
'
k
eepKeys'
:
{
'
K
eepKeys'
:
{
'keep_keys'
:
[
'image'
,
'shape'
]
'keep_keys'
:
[
'image'
,
'shape'
]
}
}
}]
}]
...
@@ -68,8 +67,9 @@ class TextDetector(object):
...
@@ -68,8 +67,9 @@ class TextDetector(object):
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
=
paddle
.
jit
.
load
(
args
.
det_model_dir
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
utility
.
create_predictor
(
self
.
predictor
.
eval
()
args
,
'det'
,
logger
)
# paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def
order_points_clockwise
(
self
,
pts
):
def
order_points_clockwise
(
self
,
pts
):
"""
"""
...
@@ -133,11 +133,23 @@ class TextDetector(object):
...
@@ -133,11 +133,23 @@ class TextDetector(object):
return
None
,
0
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
starttime
=
time
.
time
()
preds
=
self
.
predictor
(
img
)
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
zero_copy_run
()
else
:
im
=
paddle
.
fluid
.
core
.
PaddleTensor
(
img
)
self
.
predictor
.
run
([
im
])
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
# preds = self.predictor(img)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
...
@@ -146,8 +158,6 @@ class TextDetector(object):
...
@@ -146,8 +158,6 @@ class TextDetector(object):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
args
=
utility
.
parse_args
()
place
=
paddle
.
CPUPlace
()
paddle
.
disable_static
(
place
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
logger
=
get_logger
()
logger
=
get_logger
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录