Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
4402e629
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4402e629
编写于
11月 09, 2020
作者:
W
WenmuZhou
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
修正export_model里的bug,添加predict_det
上级
89e031f0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
18 deletion
+34
-18
tools/export_model.py
tools/export_model.py
+15
-9
tools/infer/predict_det.py
tools/infer/predict_det.py
+19
-9
未找到文件。
tools/export_model.py
浏览文件 @
4402e629
...
...
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
argparse
import
paddle
...
...
@@ -20,14 +27,11 @@ from paddle.jit import to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
from
tools.program
import
merge_config
def
parse_args
():
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
...
...
@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
# Please modify the 'shape' according to actual needs
@
to_static
(
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
32
,
None
],
dtype
=
'float32'
)
shape
=
[
None
,
3
,
640
,
640
],
dtype
=
'float32'
)
])
def
forward
(
self
,
inputs
):
x
=
self
.
pre_model
(
inputs
)
...
...
@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
def
main
():
FLAGS
=
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
# build model
#for rec algorithm
#
for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
...
@@ -69,7 +72,10 @@ def main():
model
.
eval
()
model
=
Model
(
model
)
paddle
.
jit
.
save
(
model
,
FLAGS
.
output_path
)
save_path
=
'{}/{}'
.
format
(
FLAGS
.
output_path
,
config
[
'Architecture'
][
'model_type'
])
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_det.py
浏览文件 @
4402e629
...
...
@@ -22,7 +22,6 @@ import cv2
import
numpy
as
np
import
time
import
sys
import
paddle
import
tools.infer.utility
as
utility
...
...
@@ -39,7 +38,7 @@ class TextDetector(object):
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
pre_process_list
=
[{
'ResizeForTest'
:
{
'
Det
ResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
}
...
...
@@ -53,7 +52,7 @@ class TextDetector(object):
},
{
'ToCHWImage'
:
None
},
{
'
k
eepKeys'
:
{
'
K
eepKeys'
:
{
'keep_keys'
:
[
'image'
,
'shape'
]
}
}]
...
...
@@ -68,8 +67,9 @@ class TextDetector(object):
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
=
paddle
.
jit
.
load
(
args
.
det_model_dir
)
self
.
predictor
.
eval
()
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
utility
.
create_predictor
(
args
,
'det'
,
logger
)
# paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def
order_points_clockwise
(
self
,
pts
):
"""
...
...
@@ -133,11 +133,23 @@ class TextDetector(object):
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
preds
=
self
.
predictor
(
img
)
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
zero_copy_run
()
else
:
im
=
paddle
.
fluid
.
core
.
PaddleTensor
(
img
)
self
.
predictor
.
run
([
im
])
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
# preds = self.predictor(img)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
...
...
@@ -146,8 +158,6 @@ class TextDetector(object):
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
place
=
paddle
.
CPUPlace
()
paddle
.
disable_static
(
place
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
logger
=
get_logger
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录