Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
47d7cac1
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
47d7cac1
编写于
11月 28, 2022
作者:
C
chenjian
提交者:
GitHub
11月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix save_inference_model bug in paddlehub (#2143)
上级
52766374
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
49 deletion
+56
-49
paddlehub/module/module.py
paddlehub/module/module.py
+56
-49
未找到文件。
paddlehub/module/module.py
浏览文件 @
47d7cac1
...
@@ -37,7 +37,6 @@ from paddlehub.utils import utils
...
@@ -37,7 +37,6 @@ from paddlehub.utils import utils
class
InvalidHubModule
(
Exception
):
class
InvalidHubModule
(
Exception
):
def
__init__
(
self
,
directory
:
str
):
def
__init__
(
self
,
directory
:
str
):
self
.
directory
=
directory
self
.
directory
=
directory
...
@@ -200,11 +199,12 @@ class RunModule(object):
...
@@ -200,11 +199,12 @@ class RunModule(object):
for
key
,
_sub_module
in
self
.
sub_modules
().
items
():
for
key
,
_sub_module
in
self
.
sub_modules
().
items
():
try
:
try
:
sub_dirname
=
os
.
path
.
normpath
(
os
.
path
.
join
(
dirname
,
key
))
sub_dirname
=
os
.
path
.
normpath
(
os
.
path
.
join
(
dirname
,
key
))
_sub_module
.
save_inference_model
(
sub_dirname
,
_sub_module
.
save_inference_model
(
include_sub_modules
=
include_sub_modules
,
sub_dirname
,
model_filename
=
model_filename
,
include_sub_modules
=
include_sub_modules
,
params_filename
=
params_filename
,
model_filename
=
model_filename
,
combined
=
combined
)
params_filename
=
params_filename
,
combined
=
combined
)
except
:
except
:
utils
.
record_exception
(
'Failed to save sub module {}'
.
format
(
_sub_module
.
name
))
utils
.
record_exception
(
'Failed to save sub module {}'
.
format
(
_sub_module
.
name
))
...
@@ -231,14 +231,11 @@ class RunModule(object):
...
@@ -231,14 +231,11 @@ class RunModule(object):
if
not
self
.
_pretrained_model_path
:
if
not
self
.
_pretrained_model_path
:
raise
RuntimeError
(
'Module {} does not support exporting models in Paddle Inference format.'
.
format
(
raise
RuntimeError
(
'Module {} does not support exporting models in Paddle Inference format.'
.
format
(
self
.
name
))
self
.
name
))
elif
not
os
.
path
.
exists
(
self
.
_pretrained_model_path
):
elif
not
os
.
path
.
exists
(
self
.
_pretrained_model_path
)
and
not
os
.
path
.
exists
(
self
.
_pretrained_model_path
+
'.pdmodel'
):
log
.
logger
.
warning
(
'The model path of Module {} does not exist.'
.
format
(
self
.
name
))
log
.
logger
.
warning
(
'The model path of Module {} does not exist.'
.
format
(
self
.
name
))
return
return
model_filename
=
'__model__'
if
not
model_filename
else
model_filename
if
combined
:
params_filename
=
'__params__'
if
not
params_filename
else
params_filename
place
=
paddle
.
CPUPlace
()
place
=
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
=
paddle
.
static
.
Executor
(
place
)
...
@@ -253,21 +250,25 @@ class RunModule(object):
...
@@ -253,21 +250,25 @@ class RunModule(object):
if
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
_pretrained_model_path
,
'__params__'
)):
if
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
_pretrained_model_path
,
'__params__'
)):
_params_filename
=
'__params__'
_params_filename
=
'__params__'
if
_model_filename
is
not
None
and
_params_filename
is
not
None
:
program
,
feeded_var_names
,
target_vars
=
paddle
.
static
.
load_inference_model
(
self
.
_pretrained_model_path
,
executor
=
exe
,
model_filename
=
_model_filename
,
params_filename
=
_params_filename
,
)
else
:
program
,
feeded_var_names
,
target_vars
=
paddle
.
static
.
load_inference_model
(
self
.
_pretrained_model_path
,
executor
=
exe
)
program
,
feeded_var_names
,
target_vars
=
paddle
.
static
.
load_inference_model
(
global_block
=
program
.
global_block
()
dirname
=
self
.
_pretrained_model_path
,
feed_vars
=
[
global_block
.
var
(
item
)
for
item
in
feeded_var_names
]
executor
=
exe
,
model_filename
=
_model_filename
,
path_prefix
=
dirname
params_filename
=
_params_filename
,
if
os
.
path
.
isdir
(
dirname
):
)
path_prefix
=
os
.
path
.
join
(
dirname
,
'model'
)
paddle
.
static
.
save_inference_model
(
paddle
.
static
.
save_inference_model
(
dirname
=
dirname
,
path_prefix
,
feed_vars
=
feed_vars
,
fetch_vars
=
target_vars
,
executor
=
exe
,
program
=
program
)
main_program
=
program
,
executor
=
exe
,
feeded_var_names
=
feeded_var_names
,
target_vars
=
target_vars
,
model_filename
=
model_filename
,
params_filename
=
params_filename
)
log
.
logger
.
info
(
'Paddle Inference model saved in {}.'
.
format
(
dirname
))
log
.
logger
.
info
(
'Paddle Inference model saved in {}.'
.
format
(
dirname
))
...
@@ -337,17 +338,19 @@ class RunModule(object):
...
@@ -337,17 +338,19 @@ class RunModule(object):
save_file
=
os
.
path
.
join
(
dirname
,
'{}.onnx'
.
format
(
self
.
name
))
save_file
=
os
.
path
.
join
(
dirname
,
'{}.onnx'
.
format
(
self
.
name
))
program
,
inputs
,
outputs
=
paddle
.
static
.
load_inference_model
(
dirname
=
self
.
_pretrained_model_path
,
program
,
inputs
,
outputs
=
paddle
.
static
.
load_inference_model
(
model_filename
=
model_filename
,
dirname
=
self
.
_pretrained_model_path
,
params_filename
=
params_filename
,
model_filename
=
model_filename
,
executor
=
exe
)
params_filename
=
params_filename
,
executor
=
exe
)
paddle2onnx
.
program2onnx
(
program
=
program
,
paddle2onnx
.
program2onnx
(
scope
=
paddle
.
static
.
global_scope
(),
program
=
program
,
feed_var_names
=
inputs
,
scope
=
paddle
.
static
.
global_scope
(),
target_vars
=
outputs
,
feed_var_names
=
inputs
,
save_file
=
save_file
,
target_vars
=
outputs
,
**
kwargs
)
save_file
=
save_file
,
**
kwargs
)
class
Module
(
object
):
class
Module
(
object
):
...
@@ -387,13 +390,14 @@ class Module(object):
...
@@ -387,13 +390,14 @@ class Module(object):
from
paddlehub.server.server
import
CacheUpdater
from
paddlehub.server.server
import
CacheUpdater
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if
name
:
if
name
:
module
=
cls
.
init_with_name
(
name
=
name
,
module
=
cls
.
init_with_name
(
version
=
version
,
name
=
name
,
source
=
source
,
version
=
version
,
update
=
update
,
source
=
source
,
branch
=
branch
,
update
=
update
,
ignore_env_mismatch
=
ignore_env_mismatch
,
branch
=
branch
,
**
kwargs
)
ignore_env_mismatch
=
ignore_env_mismatch
,
**
kwargs
)
CacheUpdater
(
"update_cache"
,
module
=
name
,
version
=
version
).
start
()
CacheUpdater
(
"update_cache"
,
module
=
name
,
version
=
version
).
start
()
elif
directory
:
elif
directory
:
module
=
cls
.
init_with_directory
(
directory
=
directory
,
**
kwargs
)
module
=
cls
.
init_with_directory
(
directory
=
directory
,
**
kwargs
)
...
@@ -485,12 +489,13 @@ class Module(object):
...
@@ -485,12 +489,13 @@ class Module(object):
manager
=
LocalModuleManager
()
manager
=
LocalModuleManager
()
user_module_cls
=
manager
.
search
(
name
,
source
=
source
,
branch
=
branch
)
user_module_cls
=
manager
.
search
(
name
,
source
=
source
,
branch
=
branch
)
if
not
user_module_cls
or
not
user_module_cls
.
version
.
match
(
version
):
if
not
user_module_cls
or
not
user_module_cls
.
version
.
match
(
version
):
user_module_cls
=
manager
.
install
(
name
=
name
,
user_module_cls
=
manager
.
install
(
version
=
version
,
name
=
name
,
source
=
source
,
version
=
version
,
update
=
update
,
source
=
source
,
branch
=
branch
,
update
=
update
,
ignore_env_mismatch
=
ignore_env_mismatch
)
branch
=
branch
,
ignore_env_mismatch
=
ignore_env_mismatch
)
directory
=
manager
.
_get_normalized_path
(
user_module_cls
.
name
)
directory
=
manager
.
_get_normalized_path
(
user_module_cls
.
name
)
...
@@ -555,7 +560,9 @@ def moduleinfo(name: str,
...
@@ -555,7 +560,9 @@ def moduleinfo(name: str,
_bases
.
append
(
_b
)
_bases
.
append
(
_b
)
_bases
.
append
(
_meta
)
_bases
.
append
(
_meta
)
_bases
=
tuple
(
_bases
)
_bases
=
tuple
(
_bases
)
wrap_cls
=
builtins
.
type
(
cls
.
__name__
,
_bases
,
dict
(
cls
.
__dict__
))
attr_dict
=
dict
(
cls
.
__dict__
)
attr_dict
.
pop
(
'__dict__'
,
None
)
wrap_cls
=
builtins
.
type
(
cls
.
__name__
,
_bases
,
attr_dict
)
wrap_cls
.
name
=
name
wrap_cls
.
name
=
name
wrap_cls
.
version
=
utils
.
Version
(
version
)
wrap_cls
.
version
=
utils
.
Version
(
version
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录