Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
186683aa
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
186683aa
编写于
5月 02, 2020
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add build_transforms_v1 for old version paddlex
上级
4a2d8927
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
60 addition
and
6 deletion
+60
-6
paddlex/cv/models/load_model.py
paddlex/cv/models/load_model.py
+60
-6
未找到文件。
paddlex/cv/models/load_model.py
浏览文件 @
186683aa
...
...
@@ -28,7 +28,12 @@ def load_model(model_dir):
raise
Exception
(
"There's not model.yml in {}"
.
format
(
model_dir
))
with
open
(
osp
.
join
(
model_dir
,
"model.yml"
))
as
f
:
info
=
yaml
.
load
(
f
.
read
(),
Loader
=
yaml
.
Loader
)
status
=
info
[
'status'
]
if
'status'
in
info
:
status
=
info
[
'status'
]
elif
'save_method'
in
info
:
# 兼容老版本PaddleX
status
=
info
[
'save_method'
]
if
not
hasattr
(
paddlex
.
cv
.
models
,
info
[
'Model'
]):
raise
Exception
(
"There's no attribute {} in paddlex.cv.models"
.
format
(
...
...
@@ -40,7 +45,7 @@ def load_model(model_dir):
model
=
getattr
(
paddlex
.
cv
.
models
,
info
[
'Model'
])(
**
info
[
'_init_params'
])
if
status
==
"Normal"
or
\
status
==
"Prune"
:
status
==
"Prune"
or
status
==
"fluid.save"
:
startup_prog
=
fluid
.
Program
()
model
.
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
model
.
test_prog
,
startup_prog
):
...
...
@@ -59,7 +64,7 @@ def load_model(model_dir):
fluid
.
io
.
set_program_state
(
model
.
test_prog
,
load_dict
)
elif
status
==
"Infer"
or
\
status
==
"Quant"
:
status
==
"Quant"
or
status
==
"fluid.save_inference_model"
:
[
prog
,
input_names
,
outputs
]
=
fluid
.
io
.
load_inference_model
(
model_dir
,
model
.
exe
,
params_filename
=
'__params__'
)
model
.
test_prog
=
prog
...
...
@@ -77,9 +82,15 @@ def load_model(model_dir):
to_rgb
=
True
else
:
to_rgb
=
False
model
.
test_transforms
=
build_transforms
(
model
.
model_type
,
info
[
'Transforms'
],
to_rgb
)
model
.
eval_transforms
=
copy
.
deepcopy
(
model
.
test_transforms
)
if
'BatchTransforms'
in
info
:
# 兼容老版本PaddleX模型
model
.
test_transforms
=
build_transforms_v1
(
model
.
model_type
,
info
[
'Transforms'
],
info
[
'BatchTransforms'
])
model
.
eval_transforms
=
copy
.
deepcopy
(
model
.
test_transforms
)
else
:
model
.
test_transforms
=
build_transforms
(
model
.
model_type
,
info
[
'Transforms'
],
to_rgb
)
model
.
eval_transforms
=
copy
.
deepcopy
(
model
.
test_transforms
)
if
'_Attributes'
in
info
:
for
k
,
v
in
info
[
'_Attributes'
].
items
():
...
...
@@ -109,3 +120,46 @@ def build_transforms(model_type, transforms_info, to_rgb=True):
eval_transforms
=
T
.
Compose
(
transforms
)
eval_transforms
.
to_rgb
=
to_rgb
return
eval_transforms
def
build_transforms_v1
(
model_type
,
transforms_info
,
batch_transforms_info
):
""" 老版本模型加载,仅支持PaddleX前端导出的模型
"""
logging
.
debug
(
"Use build_transforms_v1 to reconstruct transforms"
)
if
model_type
==
"classifier"
:
import
paddlex.cv.transforms.cls_transforms
as
T
elif
model_type
==
"detector"
:
import
paddlex.cv.transforms.det_transforms
as
T
elif
model_type
==
"segmenter"
:
import
paddlex.cv.transforms.seg_transforms
as
T
transforms
=
list
()
for
op_info
in
transforms_info
:
op_name
=
op_info
[
0
]
op_attr
=
op_info
[
1
]
if
op_name
==
'DecodeImage'
:
continue
if
op_name
==
'Permute'
:
continue
if
op_name
==
'ResizeByShort'
:
op_attr_new
=
dict
()
if
'short_size'
in
op_attr
:
op_attr_new
[
'short_size'
]
=
op_attr
[
'short_size'
]
else
:
op_attr_new
[
'short_size'
]
=
op_attr
[
'target_size'
]
op_attr_new
[
'max_size'
]
=
op_attr
.
get
(
'max_size'
,
-
1
)
op_attr
=
op_attr_new
if
op_name
.
startswith
(
'Arrange'
):
continue
if
not
hasattr
(
T
,
op_name
):
raise
Exception
(
"There's no operator named '{}' in transforms of {}"
.
format
(
op_name
,
model_type
))
transforms
.
append
(
getattr
(
T
,
op_name
)(
**
op_attr
))
if
model_type
==
"detector"
and
len
(
batch_transforms_info
)
>
0
:
op_name
=
batch_transforms_info
[
0
][
0
]
op_attr
=
batch_transforms_info
[
0
][
1
]
assert
op_name
==
"PaddingMiniBatch"
,
"Only PaddingMiniBatch transform is supported for batch transform"
padding
=
T
.
Padding
(
coarsest_stride
=
op_attr
[
'coarsest_stride'
])
transforms
.
append
(
padding
)
eval_transforms
=
T
.
Compose
(
transforms
)
return
eval_transforms
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录