Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a25545de
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
a25545de
编写于
12月 23, 2019
作者:
W
wuzewu
提交者:
GitHub
12月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Module V2 support (#274)
* Add module v2
上级
b2dc77ed
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
408 addition
and
431 deletion
+408
-431
paddlehub/__init__.py
paddlehub/__init__.py
+1
-1
paddlehub/commands/install.py
paddlehub/commands/install.py
+17
-7
paddlehub/commands/run.py
paddlehub/commands/run.py
+29
-26
paddlehub/commands/show.py
paddlehub/commands/show.py
+0
-2
paddlehub/module/check_info.proto
paddlehub/module/check_info.proto
+3
-2
paddlehub/module/check_info_pb2.py
paddlehub/module/check_info_pb2.py
+26
-11
paddlehub/module/checker.py
paddlehub/module/checker.py
+23
-13
paddlehub/module/manager.py
paddlehub/module/manager.py
+78
-42
paddlehub/module/module.py
paddlehub/module/module.py
+231
-327
未找到文件。
paddlehub/__init__.py
浏览文件 @
a25545de
...
...
@@ -38,7 +38,7 @@ from .common.logger import logger
from
.common.paddle_helper
import
connect_program
from
.common.hub_server
import
default_hub_server
from
.module.module
import
Module
,
create_module
from
.module.module
import
Module
from
.module.base_processor
import
BaseProcessor
from
.module.signature
import
Signature
,
create_signature
from
.module.manager
import
default_module_manager
...
...
paddlehub/commands/install.py
浏览文件 @
a25545de
...
...
@@ -18,6 +18,7 @@ from __future__ import division
from
__future__
import
print_function
import
argparse
import
os
from
paddlehub.common
import
utils
from
paddlehub.module.manager
import
default_module_manager
...
...
@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print
(
"ERROR: Please specify a module name.
\n
"
)
self
.
help
()
return
False
module_name
=
argv
[
0
]
module_version
=
None
if
"=="
not
in
module_name
else
module_name
.
split
(
"=="
)[
1
]
module_name
=
module_name
if
"=="
not
in
module_name
else
module_name
.
split
(
"=="
)[
0
]
extra
=
{
"command"
:
"install"
}
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
module_name
,
module_version
=
module_version
,
extra
=
extra
)
if
argv
[
0
].
endswith
(
"tar.gz"
)
or
argv
[
0
].
endswith
(
"phm"
):
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_package
=
argv
[
0
],
extra
=
extra
)
elif
os
.
path
.
exists
(
argv
[
0
])
and
os
.
path
.
isdir
(
argv
[
0
]):
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_dir
=
argv
[
0
],
extra
=
extra
)
else
:
module_name
=
argv
[
0
]
module_version
=
None
if
"=="
not
in
module_name
else
module_name
.
split
(
"=="
)[
1
]
module_name
=
module_name
if
"=="
not
in
module_name
else
module_name
.
split
(
"=="
)[
0
]
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
module_name
,
module_version
=
module_version
,
extra
=
extra
)
print
(
tips
)
return
True
...
...
paddlehub/commands/run.py
浏览文件 @
a25545de
...
...
@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
if
not
result
:
return
None
return
hub
.
Module
(
module_dir
=
module_dir
)
return
hub
.
Module
(
directory
=
module_dir
[
0
]
)
def
add_module_config_arg
(
self
):
configs
=
self
.
module
.
processor
.
configs
()
...
...
@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
def
add_module_input_arg
(
self
):
module_type
=
self
.
module
.
type
.
lower
()
expect_data_format
=
self
.
module
.
processor
.
data_format
(
self
.
module
.
default_signature
.
name
)
self
.
module
.
default_signature
)
self
.
arg_input_group
.
add_argument
(
'--input_file'
,
type
=
str
,
...
...
@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
def
get_data
(
self
):
module_type
=
self
.
module
.
type
.
lower
()
expect_data_format
=
self
.
module
.
processor
.
data_format
(
self
.
module
.
default_signature
.
name
)
self
.
module
.
default_signature
)
input_data
=
{}
if
len
(
expect_data_format
)
==
1
:
key
=
list
(
expect_data_format
.
keys
())[
0
]
...
...
@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
def
check_data
(
self
,
data
):
expect_data_format
=
self
.
module
.
processor
.
data_format
(
self
.
module
.
default_signature
.
name
)
self
.
module
.
default_signature
)
if
len
(
data
.
keys
())
!=
len
(
expect_data_format
.
keys
()):
print
(
...
...
@@ -236,35 +236,38 @@ class RunCommand(BaseCommand):
return
False
# If the module is not executable, give an alarm and exit
if
not
self
.
module
.
default_signatur
e
:
if
not
self
.
module
.
is_runabl
e
:
print
(
"ERROR! Module %s is not executable."
%
module_name
)
return
False
self
.
module
.
check_processor
()
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
if
self
.
module
.
code_version
==
"v2"
:
results
=
self
.
module
(
argv
[
1
:])
else
:
self
.
module
.
check_processor
()
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
if
not
argv
[
1
:]:
self
.
help
()
return
False
if
not
argv
[
1
:]:
self
.
help
()
return
False
self
.
args
=
self
.
parser
.
parse_args
(
argv
[
1
:])
self
.
args
=
self
.
parser
.
parse_args
(
argv
[
1
:])
config
=
self
.
get_config
()
data
=
self
.
get_data
()
config
=
self
.
get_config
()
data
=
self
.
get_data
()
try
:
self
.
check_data
(
data
)
except
DataFormatError
:
self
.
help
()
return
False
results
=
self
.
module
(
sign_name
=
self
.
module
.
default_signature
.
nam
e
,
data
=
data
,
use_gpu
=
self
.
args
.
use_gpu
,
batch_size
=
self
.
args
.
batch_size
,
**
config
)
try
:
self
.
check_data
(
data
)
except
DataFormatError
:
self
.
help
()
return
False
results
=
self
.
module
(
sign_name
=
self
.
module
.
default_signatur
e
,
data
=
data
,
use_gpu
=
self
.
args
.
use_gpu
,
batch_size
=
self
.
args
.
batch_size
,
**
config
)
if
six
.
PY2
:
try
:
...
...
paddlehub/commands/show.py
浏览文件 @
a25545de
...
...
@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
cwd
=
os
.
getcwd
()
module_dir
=
default_module_manager
.
search_module
(
module_name
)
module_dir
=
(
os
.
path
.
join
(
cwd
,
module_name
),
None
)
if
not
module_dir
else
module_dir
if
not
module_dir
or
not
os
.
path
.
exists
(
module_dir
[
0
]):
print
(
"%s is not existed!"
%
module_name
)
return
True
...
...
paddlehub/module/check_info.proto
浏览文件 @
a25545de
...
...
@@ -50,6 +50,7 @@ message CheckInfo {
string
paddle_version
=
1
;
string
hub_version
=
2
;
string
module_proto_version
=
3
;
repeated
FileInfo
file_infos
=
4
;
repeated
Requires
requires
=
5
;
string
module_code_version
=
4
;
repeated
FileInfo
file_infos
=
5
;
repeated
Requires
requires
=
6
;
};
paddlehub/module/check_info_pb2.py
浏览文件 @
a25545de
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto
...
...
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package
=
'paddlehub.module.checkinfo'
,
syntax
=
'proto3'
,
serialized_pb
=
_b
(
'
\n\x10\x63
heck_info.proto
\x12\x1a
paddlehub.module.checkinfo
\"\x85\x01\n\x08\x46
ileInfo
\x12\x11\n\t
file_name
\x18\x01
\x01
(
\t\x12\x33\n\x04
type
\x18\x02
\x01
(
\x0e\x32
%.paddlehub.module.checkinfo.FILE_TYPE
\x12\x0f\n\x07
is_need
\x18\x03
\x01
(
\x08\x12\x0b\n\x03
md5
\x18\x04
\x01
(
\t\x12\x13\n\x0b\x64\x65
scription
\x18\x05
\x01
(
\t\"\x84\x01\n\x08
Requires
\x12
>
\n\x0c
require_type
\x18\x01
\x01
(
\x0e\x32
(.paddlehub.module.checkinfo.REQUIRE_TYPE
\x12\x0f\n\x07
version
\x18\x02
\x01
(
\t\x12\x12\n\n
great_than
\x18\x03
\x01
(
\x08\x12\x13\n\x0b\x64\x65
scription
\x18\x04
\x01
(
\t\"\x
c8\x01\n\t
CheckInfo
\x12\x16\n\x0e
paddle_version
\x18\x01
\x01
(
\t\x12\x13\n\x0b
hub_version
\x18\x02
\x01
(
\t\x12\x1c\n\x14
module_proto_version
\x18\x03
\x01
(
\t\x12\x38\n\n
file_infos
\x18\x04
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.FileInfo
\x12\x36\n\x08
requires
\x18\x05
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.Requires*
\x1e\n\t
FILE_TYPE
\x12\x08\n\x04\x46
ILE
\x10\x00\x12\x07\n\x03\x44
IR
\x10\x01
*[
\n\x0c
REQUIRE_TYPE
\x12\x12\n\x0e
PYTHON_PACKAGE
\x10\x00\x12\x0e\n\n
HUB_MODULE
\x10\x01\x12\n\n\x06
SYSTEM
\x10\x02\x12\x0b\n\x07\x43
OMMAND
\x10\x03\x12\x0e\n\n
PY_VERSION
\x10\x04\x42\x02
H
\x03\x62\x06
proto3'
'
\n\x10\x63
heck_info.proto
\x12\x1a
paddlehub.module.checkinfo
\"\x85\x01\n\x08\x46
ileInfo
\x12\x11\n\t
file_name
\x18\x01
\x01
(
\t\x12\x33\n\x04
type
\x18\x02
\x01
(
\x0e\x32
%.paddlehub.module.checkinfo.FILE_TYPE
\x12\x0f\n\x07
is_need
\x18\x03
\x01
(
\x08\x12\x0b\n\x03
md5
\x18\x04
\x01
(
\t\x12\x13\n\x0b\x64\x65
scription
\x18\x05
\x01
(
\t\"\x84\x01\n\x08
Requires
\x12
>
\n\x0c
require_type
\x18\x01
\x01
(
\x0e\x32
(.paddlehub.module.checkinfo.REQUIRE_TYPE
\x12\x0f\n\x07
version
\x18\x02
\x01
(
\t\x12\x12\n\n
great_than
\x18\x03
\x01
(
\x08\x12\x13\n\x0b\x64\x65
scription
\x18\x04
\x01
(
\t\"\x
e5\x01\n\t
CheckInfo
\x12\x16\n\x0e
paddle_version
\x18\x01
\x01
(
\t\x12\x13\n\x0b
hub_version
\x18\x02
\x01
(
\t\x12\x1c\n\x14
module_proto_version
\x18\x03
\x01
(
\t\x12\x1b\n\x13
module_code_version
\x18\x04
\x01
(
\t\x12\x38\n\n
file_infos
\x18\x05
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.FileInfo
\x12\x36\n\x08
requires
\x18\x06
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.Requires*
\x1e\n\t
FILE_TYPE
\x12\x08\n\x04\x46
ILE
\x10\x00\x12\x07\n\x03\x44
IR
\x10\x01
*[
\n\x0c
REQUIRE_TYPE
\x12\x12\n\x0e
PYTHON_PACKAGE
\x10\x00\x12\x0e\n\n
HUB_MODULE
\x10\x01\x12\n\n\x06
SYSTEM
\x10\x02\x12\x0b\n\x07\x43
OMMAND
\x10\x03\x12\x0e\n\n
PY_VERSION
\x10\x04\x42\x02
H
\x03\x62\x06
proto3'
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
...
...
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
containing_type
=
None
,
options
=
None
,
serialized_start
=
5
22
,
serialized_end
=
5
52
,
serialized_start
=
5
51
,
serialized_end
=
5
81
,
)
_sym_db
.
RegisterEnumDescriptor
(
_FILE_TYPE
)
...
...
@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
],
containing_type
=
None
,
options
=
None
,
serialized_start
=
5
54
,
serialized_end
=
6
45
,
serialized_start
=
5
83
,
serialized_end
=
6
74
,
)
_sym_db
.
RegisterEnumDescriptor
(
_REQUIRE_TYPE
)
...
...
@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'
file_infos
'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.
file_infos
'
,
name
=
'
module_code_version
'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.
module_code_version
'
,
index
=
3
,
number
=
4
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'file_infos'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.file_infos'
,
index
=
4
,
number
=
5
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
...
...
@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor
.
FieldDescriptor
(
name
=
'requires'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.requires'
,
index
=
4
,
number
=
5
,
index
=
5
,
number
=
6
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
...
...
@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
320
,
serialized_end
=
5
20
,
serialized_end
=
5
49
,
)
_FILEINFO
.
fields_by_name
[
'type'
].
enum_type
=
_FILE_TYPE
...
...
paddlehub/module/checker.py
浏览文件 @
a25545de
...
...
@@ -32,20 +32,22 @@ FILE_SEP = "/"
class
ModuleChecker
(
object
):
def
__init__
(
self
,
module_path
):
self
.
module_path
=
module_path
def
__init__
(
self
,
directory
):
self
.
_directory
=
directory
self
.
_pb_path
=
os
.
path
.
join
(
self
.
directory
,
CHECK_INFO_PB_FILENAME
)
def
generate_check_info
(
self
):
check_info
=
check_info_pb2
.
CheckInfo
()
check_info
.
paddle_version
=
paddle
.
__version__
check_info
.
hub_version
=
hub_version
check_info
.
module_proto_version
=
module_proto_version
check_info
.
module_code_version
=
"v2"
file_infos
=
check_info
.
file_infos
file_list
=
[
file
for
file
in
os
.
listdir
(
self
.
module_path
)]
file_list
=
[
file
for
file
in
os
.
listdir
(
self
.
directory
)]
while
file_list
:
file
=
file_list
[
0
]
file_list
=
file_list
[
1
:]
abs_path
=
os
.
path
.
join
(
self
.
module_path
,
file
)
abs_path
=
os
.
path
.
join
(
self
.
directory
,
file
)
if
os
.
path
.
isdir
(
abs_path
):
for
sub_file
in
os
.
listdir
(
abs_path
):
sub_file
=
os
.
path
.
join
(
file
,
sub_file
)
...
...
@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info
.
type
=
check_info_pb2
.
FILE
file_info
.
is_need
=
True
with
open
(
os
.
path
.
join
(
self
.
module_path
,
CHECK_INFO_PB_FILENAME
),
"wb"
)
as
fi
:
fi
.
write
(
check_info
.
SerializeToString
())
with
open
(
self
.
pb_path
,
"wb"
)
as
file
:
file
.
write
(
check_info
.
SerializeToString
())
@
property
def
module_code_version
(
self
):
return
self
.
check_info
.
module_code_version
@
property
def
module_proto_version
(
self
):
...
...
@@ -82,20 +87,25 @@ class ModuleChecker(object):
def
file_infos
(
self
):
return
self
.
check_info
.
file_infos
@
property
def
directory
(
self
):
return
self
.
_directory
@
property
def
pb_path
(
self
):
return
self
.
_pb_path
def
check
(
self
):
result
=
True
self
.
check_info_pb_path
=
os
.
path
.
join
(
self
.
module_path
,
CHECK_INFO_PB_FILENAME
)
if
not
(
os
.
path
.
exists
(
self
.
check_info_pb_path
)
or
os
.
path
.
isfile
(
self
.
check_info_pb_path
)):
if
not
(
os
.
path
.
exists
(
self
.
pb_path
)
or
os
.
path
.
isfile
(
self
.
pb_path
)):
logger
.
warning
(
"This module lacks core file %s"
%
CHECK_INFO_PB_FILENAME
)
result
=
False
self
.
check_info
=
check_info_pb2
.
CheckInfo
()
try
:
with
open
(
self
.
check_info_
pb_path
,
"rb"
)
as
fi
:
with
open
(
self
.
pb_path
,
"rb"
)
as
fi
:
pb_string
=
fi
.
read
()
result
=
self
.
check_info
.
ParseFromString
(
pb_string
)
if
len
(
pb_string
)
==
0
or
(
result
is
not
None
...
...
@@ -182,7 +192,7 @@ class ModuleChecker(object):
for
file_info
in
self
.
file_infos
:
file_type
=
file_info
.
type
file_path
=
file_info
.
file_name
.
replace
(
FILE_SEP
,
os
.
sep
)
file_path
=
os
.
path
.
join
(
self
.
module_path
,
file_path
)
file_path
=
os
.
path
.
join
(
self
.
directory
,
file_path
)
if
not
os
.
path
.
exists
(
file_path
):
if
file_info
.
is_need
:
logger
.
warning
(
...
...
paddlehub/module/manager.py
浏览文件 @
a25545de
...
...
@@ -19,6 +19,7 @@ from __future__ import print_function
import
os
import
shutil
import
tarfile
from
paddlehub.common
import
utils
from
paddlehub.common
import
srv_utils
...
...
@@ -77,15 +78,76 @@ class LocalModuleManager(object):
return
self
.
modules_dict
.
get
(
module_name
,
None
)
def
install_module
(
self
,
module_name
,
module_name
=
None
,
module_dir
=
None
,
module_package
=
None
,
module_version
=
None
,
upgrade
=
False
,
extra
=
None
):
self
.
all_modules
(
update
=
True
)
module_info
=
self
.
modules_dict
.
get
(
module_name
,
None
)
if
module_info
:
if
not
module_version
or
module_version
==
self
.
modules_dict
[
module_name
][
1
]:
md5_value
=
installed_module_version
=
None
from_user_dir
=
True
if
module_dir
else
False
if
module_name
:
self
.
all_modules
(
update
=
True
)
module_info
=
self
.
modules_dict
.
get
(
module_name
,
None
)
if
module_info
:
if
not
module_version
or
module_version
==
self
.
modules_dict
[
module_name
][
1
]:
module_dir
=
self
.
modules_dict
[
module_name
][
0
]
module_tag
=
module_name
if
not
module_version
else
'%s-%s'
%
(
module_name
,
module_version
)
tips
=
"Module %s already installed in %s"
%
(
module_tag
,
module_dir
)
return
True
,
tips
,
self
.
modules_dict
[
module_name
]
search_result
=
hub
.
default_hub_server
.
get_module_url
(
module_name
,
version
=
module_version
,
extra
=
extra
)
name
=
search_result
.
get
(
'name'
,
None
)
url
=
search_result
.
get
(
'url'
,
None
)
md5_value
=
search_result
.
get
(
'md5'
,
None
)
installed_module_version
=
search_result
.
get
(
'version'
,
None
)
if
not
url
or
(
module_version
is
not
None
and
installed_module_version
!=
module_version
)
or
(
name
!=
module_name
):
if
default_hub_server
.
_server_check
()
is
False
:
tips
=
"Request Hub-Server unsuccessfully, please check your network."
else
:
tips
=
"Can't find module %s"
%
module_name
if
module_version
:
tips
+=
" with version %s"
%
module_version
module_tag
=
module_name
if
not
module_version
else
'%s-%s'
%
(
module_name
,
module_version
)
return
False
,
tips
,
None
result
,
tips
,
module_zip_file
=
default_downloader
.
download_file
(
url
=
url
,
save_path
=
hub
.
CACHE_HOME
,
save_name
=
module_name
,
replace
=
True
,
print_progress
=
True
)
result
,
tips
,
module_dir
=
default_downloader
.
uncompress
(
file
=
module_zip_file
,
dirname
=
MODULE_HOME
,
delete_file
=
True
,
print_progress
=
True
)
if
module_package
:
with
tarfile
.
open
(
module_package
,
"r:gz"
)
as
tar
:
file_names
=
tar
.
getnames
()
size
=
len
(
file_names
)
-
1
module_dir
=
os
.
path
.
split
(
file_names
[
0
])[
0
]
module_dir
=
os
.
path
.
join
(
hub
.
CACHE_HOME
,
module_dir
)
# remove cache
if
os
.
path
.
exists
(
module_dir
):
shutil
.
rmtree
(
module_dir
)
for
index
,
file_name
in
enumerate
(
file_names
):
tar
.
extract
(
file_name
,
hub
.
CACHE_HOME
)
if
module_dir
:
if
not
module_name
:
module_name
=
hub
.
Module
(
directory
=
module_dir
).
name
self
.
all_modules
(
update
=
False
)
module_info
=
self
.
modules_dict
.
get
(
module_name
,
None
)
if
module_info
:
module_dir
=
self
.
modules_dict
[
module_name
][
0
]
module_tag
=
module_name
if
not
module_version
else
'%s-%s'
%
(
module_name
,
module_version
)
...
...
@@ -93,44 +155,18 @@ class LocalModuleManager(object):
module_dir
)
return
True
,
tips
,
self
.
modules_dict
[
module_name
]
search_result
=
hub
.
default_hub_server
.
get_module_url
(
module_name
,
version
=
module_version
,
extra
=
extra
)
name
=
search_result
.
get
(
'name'
,
None
)
url
=
search_result
.
get
(
'url'
,
None
)
md5_value
=
search_result
.
get
(
'md5'
,
None
)
installed_module_version
=
search_result
.
get
(
'version'
,
None
)
if
not
url
or
(
module_version
is
not
None
and
installed_module_version
!=
module_version
)
or
(
name
!=
module_name
):
if
default_hub_server
.
_server_check
()
is
False
:
tips
=
"Request Hub-Server unsuccessfully, please check your network."
else
:
tips
=
"Can't find module %s"
%
module_name
if
module_version
:
tips
+=
" with version %s"
%
module_version
module_tag
=
module_name
if
not
module_version
else
'%s-%s'
%
(
module_name
,
module_version
)
return
False
,
tips
,
None
result
,
tips
,
module_zip_file
=
default_downloader
.
download_file
(
url
=
url
,
save_path
=
hub
.
CACHE_HOME
,
save_name
=
module_name
,
replace
=
True
,
print_progress
=
True
)
result
,
tips
,
module_dir
=
default_downloader
.
uncompress
(
file
=
module_zip_file
,
dirname
=
MODULE_HOME
,
delete_file
=
True
,
print_progress
=
True
)
if
module_dir
:
with
open
(
os
.
path
.
join
(
MODULE_HOME
,
module_dir
,
"md5.txt"
),
"w"
)
as
fp
:
fp
.
write
(
md5_value
)
if
md5_value
:
with
open
(
os
.
path
.
join
(
MODULE_HOME
,
module_dir
,
"md5.txt"
),
"w"
)
as
fp
:
fp
.
write
(
md5_value
)
save_path
=
os
.
path
.
join
(
MODULE_HOME
,
module_name
)
if
os
.
path
.
exists
(
save_path
):
shutil
.
rmtree
(
save_path
)
shutil
.
move
(
module_dir
,
save_path
)
shutil
.
move
(
save_path
)
if
from_user_dir
:
shutil
.
copytree
(
module_dir
,
save_path
)
else
:
shutil
.
move
(
module_dir
,
save_path
)
module_dir
=
save_path
tips
=
"Successfully installed %s"
%
module_name
if
installed_module_version
:
...
...
paddlehub/module/module.py
浏览文件 @
a25545de
...
...
@@ -21,6 +21,10 @@ import os
import
time
import
sys
import
functools
import
inspect
import
importlib
import
tarfile
from
collections
import
defaultdict
from
shutil
import
copyfile
import
paddle
...
...
@@ -28,22 +32,19 @@ import paddle.fluid as fluid
from
paddlehub.common
import
utils
from
paddlehub.common
import
paddle_helper
from
paddlehub.common.
logger
import
logger
from
paddlehub.common.
dir
import
CACHE_HOME
from
paddlehub.common.lock
import
lock
from
paddlehub.common.downloader
import
default_downloader
from
paddlehub.common.logger
import
logger
from
paddlehub.common.hub_server
import
CacheUpdater
from
paddlehub.module
import
module_desc_pb2
from
paddlehub.common.dir
import
CONF_HOME
from
paddlehub.module
import
check_info_pb2
from
paddlehub.common.hub_server
import
CacheUpdater
from
paddlehub.module.signature
import
Signature
,
create_signature
from
paddlehub.module.checker
import
ModuleChecker
from
paddlehub.module.manager
import
default_module_manager
from
paddlehub.module.checker
import
ModuleChecker
from
paddlehub.module.signature
import
Signature
,
create_signature
from
paddlehub.module.base_processor
import
BaseProcessor
from
paddlehub.io.parser
import
yaml_parser
from
paddlehub
import
version
__all__
=
[
'Module'
,
'create_module'
]
# PaddleHub module dir name
ASSETS_DIRNAME
=
"assets"
MODEL_DIRNAME
=
"model"
...
...
@@ -52,67 +53,226 @@ PYTHON_DIR = "python"
PROCESSOR_NAME
=
"processor"
# PaddleHub var prefix
HUB_VAR_PREFIX
=
"@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX
=
"phm"
def
create_module
(
directory
,
name
,
author
,
email
,
module_type
,
summary
,
version
):
save_file_name
=
"{}-{}.{}"
.
format
(
name
,
version
,
HUB_PACKAGE_SUFFIX
)
# record module info and serialize
desc
=
module_desc_pb2
.
ModuleDesc
()
attr
=
desc
.
attr
attr
.
type
=
module_desc_pb2
.
MAP
module_info
=
attr
.
map
.
data
[
'module_info'
]
module_info
.
type
=
module_desc_pb2
.
MAP
utils
.
from_pyobj_to_module_attr
(
name
,
module_info
.
map
.
data
[
'name'
])
utils
.
from_pyobj_to_module_attr
(
author
,
module_info
.
map
.
data
[
'author'
])
utils
.
from_pyobj_to_module_attr
(
email
,
module_info
.
map
.
data
[
'author_email'
])
utils
.
from_pyobj_to_module_attr
(
module_type
,
module_info
.
map
.
data
[
'type'
])
utils
.
from_pyobj_to_module_attr
(
summary
,
module_info
.
map
.
data
[
'summary'
])
utils
.
from_pyobj_to_module_attr
(
version
,
module_info
.
map
.
data
[
'version'
])
module_desc_path
=
os
.
path
.
join
(
directory
,
"module_desc.pb"
)
with
open
(
module_desc_path
,
"wb"
)
as
f
:
f
.
write
(
desc
.
SerializeToString
())
# generate check info
checker
=
ModuleChecker
(
directory
)
checker
.
generate_check_info
()
# add __init__
module_init_1
=
os
.
path
.
join
(
directory
,
"__init__.py"
)
with
open
(
module_init_1
,
"a"
)
as
file
:
file
.
write
(
""
)
module_init_2
=
os
.
path
.
join
(
directory
,
"python"
,
"__init__.py"
)
with
open
(
module_init_2
,
"a"
)
as
file
:
file
.
write
(
""
)
# package the module
with
tarfile
.
open
(
save_file_name
,
"w:gz"
)
as
tar
:
for
dirname
,
_
,
files
in
os
.
walk
(
directory
):
for
file
in
files
:
tar
.
add
(
os
.
path
.
join
(
dirname
,
file
))
os
.
remove
(
module_desc_path
)
os
.
remove
(
checker
.
pb_path
)
os
.
remove
(
module_init_1
)
os
.
remove
(
module_init_2
)
class
Module
(
object
):
def
__new__
(
cls
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
version
=
None
):
module
=
None
if
cls
.
__name__
==
"Module"
:
if
name
:
module
=
cls
.
init_with_name
(
name
=
name
,
version
=
version
)
elif
directory
:
module
=
cls
.
init_with_directory
(
directory
=
directory
)
elif
module_dir
:
logger
.
warning
(
"Parameter module_dir is deprecated, please use directory to specify the path"
)
if
isinstance
(
module_dir
,
list
)
or
isinstance
(
module_dir
,
tuple
):
directory
=
module_dir
[
0
]
version
=
module_dir
[
1
]
else
:
directory
=
module_dir
module
=
cls
.
init_with_directory
(
directory
=
directory
)
if
not
module
:
module
=
object
.
__new__
(
cls
)
CacheUpdater
(
name
,
version
).
start
()
return
module
def
__init__
(
self
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
version
=
None
):
if
not
directory
:
return
self
.
_code_version
=
"v2"
self
.
_directory
=
directory
self
.
module_desc_path
=
os
.
path
.
join
(
self
.
directory
,
MODULE_DESC_PBNAME
)
self
.
_desc
=
module_desc_pb2
.
ModuleDesc
()
with
open
(
self
.
module_desc_path
,
"rb"
)
as
file
:
self
.
_desc
.
ParseFromString
(
file
.
read
())
module_info
=
self
.
desc
.
attr
.
map
.
data
[
'module_info'
]
self
.
_name
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'name'
])
self
.
_author
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author'
])
self
.
_author_email
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author_email'
])
self
.
_version
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'version'
])
self
.
_type
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'type'
])
self
.
_summary
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'summary'
])
self
.
_initialize
()
@
classmethod
def
init_with_name
(
cls
,
name
,
version
=
None
):
fp_lock
=
open
(
os
.
path
.
join
(
CACHE_HOME
,
name
),
"a"
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_EX
)
log_msg
=
"Installing %s module"
%
name
if
version
:
log_msg
+=
"-%s"
%
version
logger
.
info
(
log_msg
)
extra
=
{
"command"
:
"install"
}
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
name
,
module_version
=
version
,
extra
=
extra
)
if
not
result
:
logger
.
error
(
tips
)
raise
RuntimeError
(
tips
)
logger
.
info
(
tips
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
return
cls
.
init_with_directory
(
directory
=
module_dir
[
0
])
def
create_module
(
sign_arr
,
module_dir
,
processor
=
None
,
assets
=
None
,
module_info
=
None
,
exe
=
None
,
extra_info
=
None
):
sign_arr
=
utils
.
to_list
(
sign_arr
)
module
=
Module
(
signatures
=
sign_arr
,
processor
=
processor
,
assets
=
assets
,
module_info
=
module_info
,
extra_info
=
extra_info
)
module
.
serialize_to_path
(
path
=
module_dir
,
exe
=
exe
)
@
classmethod
def
init_with_directory
(
cls
,
directory
):
desc_file
=
os
.
path
.
join
(
directory
,
MODULE_DESC_PBNAME
)
checker
=
ModuleChecker
(
directory
)
checker
.
check
()
module_code_version
=
checker
.
module_code_version
if
module_code_version
==
"v2"
:
basename
=
os
.
path
.
split
(
directory
)[
-
1
]
dirname
=
os
.
path
.
join
(
*
list
(
os
.
path
.
split
(
directory
)[:
-
1
]))
sys
.
path
.
append
(
dirname
)
pymodule
=
importlib
.
import_module
(
"{}.python.module"
.
format
(
basename
))
return
pymodule
.
HubModule
(
directory
=
directory
)
return
ModuleV1
(
directory
=
directory
)
@
property
def
desc
(
self
):
return
self
.
_desc
@
property
def
directory
(
self
):
return
self
.
_directory
@
property
def
author
(
self
):
return
self
.
_author
@
property
def
author_email
(
self
):
return
self
.
_author_email
@
property
def
summary
(
self
):
return
self
.
_summary
@
property
def
type
(
self
):
return
self
.
_type
@
property
def
version
(
self
):
return
self
.
_version
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
name_prefix
(
self
):
return
self
.
_name_prefix
@
property
def
code_version
(
self
):
return
self
.
_code_version
@
property
def
is_runable
(
self
):
return
False
def
_initialize
(
self
):
pass
class
ModuleHelper
(
object
):
def
__init__
(
self
,
module_dir
):
self
.
module_dir
=
module_dir
def
__init__
(
self
,
directory
):
self
.
directory
=
directory
def
module_desc_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
MODULE_DESC_PBNAME
)
return
os
.
path
.
join
(
self
.
directory
,
MODULE_DESC_PBNAME
)
def
model_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
MODEL_DIRNAME
)
return
os
.
path
.
join
(
self
.
directory
,
MODEL_DIRNAME
)
def
processor_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
PYTHON_DIR
)
return
os
.
path
.
join
(
self
.
directory
,
PYTHON_DIR
)
def
processor_name
(
self
):
return
PROCESSOR_NAME
def
assets_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
ASSETS_DIRNAME
)
return
os
.
path
.
join
(
self
.
directory
,
ASSETS_DIRNAME
)
class
Module
(
object
):
def
__init__
(
self
,
name
=
None
,
module_dir
=
None
,
signatures
=
None
,
module_info
=
None
,
assets
=
None
,
processor
=
None
,
extra_info
=
None
,
class
ModuleV1
(
Module
):
def
__init__
(
self
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
version
=
None
):
self
.
desc
=
module_desc_pb2
.
ModuleDesc
()
if
not
directory
:
return
super
(
ModuleV1
,
self
).
__init__
(
name
,
directory
,
module_dir
,
version
)
self
.
_code_version
=
"v1"
self
.
program
=
None
self
.
assets
=
[]
self
.
helper
=
None
self
.
signatures
=
{}
self
.
default_signature
=
None
self
.
module_info
=
None
self
.
processor
=
None
self
.
extra_info
=
{}
if
extra_info
is
None
else
extra_info
if
not
isinstance
(
self
.
extra_info
,
dict
):
raise
TypeError
(
"The extra_info should be an instance of python dict"
)
self
.
extra_info
=
{}
# cache data
self
.
last_call_name
=
None
...
...
@@ -120,62 +280,21 @@ class Module(object):
self
.
cache_fetch_dict
=
None
self
.
cache_program
=
None
fp_lock
=
open
(
os
.
path
.
join
(
CONF_HOME
,
'config.json'
))
lock
.
flock
(
fp_lock
,
lock
.
LOCK_EX
)
if
name
:
self
.
_init_with_name
(
name
=
name
,
version
=
version
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
elif
module_dir
:
self
.
_init_with_module_file
(
module_dir
=
module_dir
[
0
])
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
name
=
module_dir
[
0
].
split
(
"/"
)[
-
1
]
if
len
(
module_dir
)
>
1
:
version
=
module_dir
[
1
]
else
:
version
=
default_module_manager
.
search_module
(
name
)[
1
]
elif
signatures
:
if
processor
:
if
not
issubclass
(
processor
,
BaseProcessor
):
raise
TypeError
(
"Processor shoule be an instance of paddlehub.BaseProcessor"
)
if
assets
:
self
.
assets
=
utils
.
to_list
(
assets
)
# for asset in assets:
# utils.check_path(assets)
self
.
processor
=
processor
self
.
_generate_module_info
(
module_info
)
self
.
_init_with_signature
(
signatures
=
signatures
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
else
:
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
raise
ValueError
(
"Module initialized parameter is empty"
)
CacheUpdater
(
name
,
version
).
start
()
def
_init_with_name
(
self
,
name
,
version
=
None
):
log_msg
=
"Installing %s module"
%
name
if
version
:
log_msg
+=
"-%s"
%
version
logger
.
info
(
log_msg
)
extra
=
{
"command"
:
"install"
}
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
name
,
module_version
=
version
,
extra
=
extra
)
if
not
result
:
logger
.
error
(
tips
)
raise
RuntimeError
(
tips
)
else
:
logger
.
info
(
tips
)
self
.
_init_with_module_file
(
module_dir
[
0
])
def
_init_with_url
(
self
,
url
):
utils
.
check_url
(
url
)
result
,
tips
,
module_dir
=
default_downloader
.
download_file_and_uncompress
(
url
,
save_path
=
"."
)
if
not
result
:
logger
.
error
(
tips
)
raise
RuntimeError
(
tips
)
else
:
self
.
_init_with_module_file
(
module_dir
)
self
.
helper
=
ModuleHelper
(
directory
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
self
.
program
,
_
,
_
=
fluid
.
io
.
load_inference_model
(
self
.
helper
.
model_path
(),
executor
=
exe
)
for
block
in
self
.
program
.
blocks
:
for
op
in
block
.
ops
:
if
"op_callstack"
in
op
.
all_attrs
():
op
.
_set_attr
(
"op_callstack"
,
[
""
])
self
.
_load_processor
()
self
.
_load_assets
()
self
.
_recover_from_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
self
.
_restore_parameter
(
self
.
program
)
self
.
_recover_variable_info
(
self
.
program
)
def
_dump_processor
(
self
):
import
inspect
...
...
@@ -216,52 +335,6 @@ class Module(object):
filepath
=
os
.
path
.
join
(
self
.
helper
.
assets_path
(),
file
)
self
.
assets
.
append
(
filepath
)
def
_init_with_module_file
(
self
,
module_dir
):
checker
=
ModuleChecker
(
module_dir
)
checker
.
check
()
self
.
helper
=
ModuleHelper
(
module_dir
)
with
open
(
self
.
helper
.
module_desc_path
(),
"rb"
)
as
fi
:
self
.
desc
.
ParseFromString
(
fi
.
read
())
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
self
.
program
,
_
,
_
=
fluid
.
io
.
load_inference_model
(
self
.
helper
.
model_path
(),
executor
=
exe
)
for
block
in
self
.
program
.
blocks
:
for
op
in
block
.
ops
:
if
"op_callstack"
in
op
.
all_attrs
():
op
.
_set_attr
(
"op_callstack"
,
[
""
])
self
.
_load_processor
()
self
.
_load_assets
()
self
.
_recover_from_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
self
.
_restore_parameter
(
self
.
program
)
self
.
_recover_variable_info
(
self
.
program
)
def
_init_with_signature
(
self
,
signatures
):
self
.
name_prefix
=
HUB_VAR_PREFIX
%
self
.
name
self
.
_process_signatures
(
signatures
)
self
.
_check_signatures
()
self
.
_generate_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
def
_init_with_program
(
self
,
program
):
pass
def
_process_signatures
(
self
,
signatures
):
self
.
signatures
=
{}
self
.
program
=
signatures
[
0
].
inputs
[
0
].
block
.
program
for
sign
in
signatures
:
if
sign
.
name
in
self
.
signatures
:
raise
ValueError
(
"Error! Signature array contains duplicated signatrues %s"
%
sign
)
if
self
.
default_signature
is
None
and
sign
.
for_predict
:
self
.
default_signature
=
sign
self
.
signatures
[
sign
.
name
]
=
sign
def
_restore_parameter
(
self
,
program
):
global_block
=
program
.
global_block
()
param_attrs
=
self
.
desc
.
attr
.
map
.
data
[
'param_attrs'
]
...
...
@@ -302,21 +375,6 @@ class Module(object):
self
.
__dict__
[
"get_%s"
%
key
]
=
functools
.
partial
(
self
.
get_extra_info
,
key
=
key
)
def
_generate_module_info
(
self
,
module_info
=
None
):
if
not
module_info
:
self
.
module_info
=
{}
else
:
if
not
utils
.
is_yaml_file
(
module_info
):
logger
.
critical
(
"Module info file should be yaml format"
)
exit
(
1
)
self
.
module_info
=
yaml_parser
.
parse
(
module_info
)
self
.
author
=
self
.
module_info
.
get
(
'author'
,
'UNKNOWN'
)
self
.
author_email
=
self
.
module_info
.
get
(
'author_email'
,
'UNKNOWN'
)
self
.
summary
=
self
.
module_info
.
get
(
'summary'
,
'UNKNOWN'
)
self
.
type
=
self
.
module_info
.
get
(
'type'
,
'UNKNOWN'
)
self
.
version
=
self
.
module_info
.
get
(
'version'
,
'UNKNOWN'
)
self
.
name
=
self
.
module_info
.
get
(
'name'
,
'UNKNOWN'
)
def
_generate_sign_attr
(
self
):
self
.
_check_signatures
()
for
sign
in
self
.
signatures
:
...
...
@@ -369,21 +427,21 @@ class Module(object):
default_signature_name
=
utils
.
from_module_attr_to_pyobj
(
self
.
desc
.
attr
.
map
.
data
[
'default_signature'
])
self
.
default_signature
=
self
.
signatures
[
default_signature_name
]
if
default_signature_name
else
None
default_signature_name
]
.
name
if
default_signature_name
else
None
# recover module info
module_info
=
self
.
desc
.
attr
.
map
.
data
[
'module_info'
]
self
.
name
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
name
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'name'
])
self
.
author
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
author
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author'
])
self
.
author_email
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
author_email
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author_email'
])
self
.
version
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
version
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'version'
])
self
.
type
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
type
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'type'
])
self
.
summary
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
summary
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'summary'
])
# recover extra info
...
...
@@ -393,77 +451,9 @@ class Module(object):
self
.
extra_info
[
key
]
=
utils
.
from_module_attr_to_pyobj
(
value
)
# recover name prefix
self
.
name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
def
_generate_desc
(
self
):
# save fluid Parameter
attr
=
self
.
desc
.
attr
attr
.
type
=
module_desc_pb2
.
MAP
param_attrs
=
attr
.
map
.
data
[
'param_attrs'
]
param_attrs
.
type
=
module_desc_pb2
.
MAP
for
param
in
self
.
program
.
global_block
().
iter_parameters
():
param_attr
=
param_attrs
.
map
.
data
[
param
.
name
]
paddle_helper
.
from_param_to_module_attr
(
param
,
param_attr
)
# save Variable Info
var_infos
=
attr
.
map
.
data
[
'var_infos'
]
var_infos
.
type
=
module_desc_pb2
.
MAP
for
block
in
self
.
program
.
blocks
:
for
var
in
block
.
vars
.
values
():
var_info
=
var_infos
.
map
.
data
[
var
.
name
]
var_info
.
type
=
module_desc_pb2
.
MAP
utils
.
from_pyobj_to_module_attr
(
var
.
stop_gradient
,
var_info
.
map
.
data
[
'stop_gradient'
])
utils
.
from_pyobj_to_module_attr
(
block
.
idx
,
var_info
.
map
.
data
[
'block_id'
])
# save signarture info
for
key
,
sign
in
self
.
signatures
.
items
():
var
=
self
.
desc
.
sign2var
[
sign
.
name
]
feed_desc
=
var
.
feed_desc
fetch_desc
=
var
.
fetch_desc
feed_names
=
sign
.
feed_names
fetch_names
=
sign
.
fetch_names
for
index
,
input
in
enumerate
(
sign
.
inputs
):
feed_var
=
feed_desc
.
add
()
feed_var
.
var_name
=
self
.
get_var_name_with_prefix
(
input
.
name
)
feed_var
.
alias
=
feed_names
[
index
]
for
index
,
output
in
enumerate
(
sign
.
outputs
):
fetch_var
=
fetch_desc
.
add
()
fetch_var
.
var_name
=
self
.
get_var_name_with_prefix
(
output
.
name
)
fetch_var
.
alias
=
fetch_names
[
index
]
# save default signature
utils
.
from_pyobj_to_module_attr
(
self
.
default_signature
.
name
if
self
.
default_signature
else
None
,
attr
.
map
.
data
[
'default_signature'
])
# save name prefix
utils
.
from_pyobj_to_module_attr
(
self
.
name_prefix
,
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
# save module info
module_info
=
attr
.
map
.
data
[
'module_info'
]
module_info
.
type
=
module_desc_pb2
.
MAP
utils
.
from_pyobj_to_module_attr
(
self
.
name
,
module_info
.
map
.
data
[
'name'
])
utils
.
from_pyobj_to_module_attr
(
self
.
version
,
module_info
.
map
.
data
[
'version'
])
utils
.
from_pyobj_to_module_attr
(
self
.
author
,
module_info
.
map
.
data
[
'author'
])
utils
.
from_pyobj_to_module_attr
(
self
.
author_email
,
module_info
.
map
.
data
[
'author_email'
])
utils
.
from_pyobj_to_module_attr
(
self
.
type
,
module_info
.
map
.
data
[
'type'
])
utils
.
from_pyobj_to_module_attr
(
self
.
summary
,
module_info
.
map
.
data
[
'summary'
])
# save extra info
extra_info
=
attr
.
map
.
data
[
'extra_info'
]
extra_info
.
type
=
module_desc_pb2
.
MAP
for
key
,
value
in
self
.
extra_info
.
items
():
utils
.
from_pyobj_to_module_attr
(
value
,
extra_info
.
map
.
data
[
key
])
def
__call__
(
self
,
sign_name
,
data
,
use_gpu
=
False
,
batch_size
=
1
,
**
kwargs
):
self
.
check_processor
()
...
...
@@ -525,6 +515,10 @@ class Module(object):
if
not
self
.
processor
:
raise
ValueError
(
"This Module is not callable!"
)
@
property
def
is_runable
(
self
):
return
self
.
default_signature
!=
None
def
context
(
self
,
sign_name
=
None
,
for_test
=
False
,
...
...
@@ -664,93 +658,3 @@ class Module(object):
raise
ValueError
(
"All input and outputs variables in signature should come from the same Program"
)
def
serialize_to_path
(
self
,
path
=
None
,
exe
=
None
):
self
.
_check_signatures
()
self
.
_generate_desc
()
# create module path for saving
if
path
is
None
:
path
=
os
.
path
.
join
(
"."
,
self
.
name
)
self
.
helper
=
ModuleHelper
(
path
)
utils
.
mkdir
(
self
.
helper
.
module_dir
)
# create module pb
module_desc
=
module_desc_pb2
.
ModuleDesc
()
logger
.
info
(
"PaddleHub version = %s"
%
version
.
hub_version
)
logger
.
info
(
"PaddleHub Module proto version = %s"
%
version
.
module_proto_version
)
logger
.
info
(
"Paddle version = %s"
%
paddle
.
__version__
)
feeded_var_names
=
[
input
.
name
for
key
,
sign
in
self
.
signatures
.
items
()
for
input
in
sign
.
inputs
]
target_vars
=
[
output
for
key
,
sign
in
self
.
signatures
.
items
()
for
output
in
sign
.
outputs
]
feeded_var_names
=
list
(
set
(
feeded_var_names
))
target_vars
=
list
(
set
(
target_vars
))
# save inference program
program
=
self
.
program
.
clone
()
for
block
in
program
.
blocks
:
for
op
in
block
.
ops
:
if
"op_callstack"
in
op
.
all_attrs
():
op
.
_set_attr
(
"op_callstack"
,
[
""
])
if
not
exe
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
=
place
)
utils
.
mkdir
(
self
.
helper
.
model_path
())
fluid
.
io
.
save_inference_model
(
self
.
helper
.
model_path
(),
feeded_var_names
=
list
(
feeded_var_names
),
target_vars
=
list
(
target_vars
),
main_program
=
program
,
executor
=
exe
)
with
open
(
os
.
path
.
join
(
self
.
helper
.
model_path
(),
"__model__"
),
"rb"
)
as
file
:
program_desc_str
=
file
.
read
()
rename_program
=
fluid
.
framework
.
Program
.
parse_from_string
(
program_desc_str
)
varlist
=
{
var
:
block
for
block
in
rename_program
.
blocks
for
var
in
block
.
vars
if
self
.
get_name_prefix
()
not
in
var
}
for
var
,
block
in
varlist
.
items
():
old_name
=
var
new_name
=
self
.
get_var_name_with_prefix
(
old_name
)
block
.
_rename_var
(
old_name
,
new_name
)
utils
.
mkdir
(
self
.
helper
.
model_path
())
with
open
(
os
.
path
.
join
(
self
.
helper
.
model_path
(),
"__model__"
),
"wb"
)
as
f
:
f
.
write
(
rename_program
.
desc
.
serialize_to_string
())
for
file
in
os
.
listdir
(
self
.
helper
.
model_path
()):
if
(
file
==
"__model__"
or
self
.
get_name_prefix
()
in
file
):
continue
os
.
rename
(
os
.
path
.
join
(
self
.
helper
.
model_path
(),
file
),
os
.
path
.
join
(
self
.
helper
.
model_path
(),
self
.
get_var_name_with_prefix
(
file
)))
# create processor file
if
self
.
processor
:
self
.
_dump_processor
()
# create assets
self
.
_dump_assets
()
# create check info
checker
=
ModuleChecker
(
self
.
helper
.
module_dir
)
checker
.
generate_check_info
()
# Serialize module_desc pb
module_pb
=
self
.
desc
.
SerializeToString
()
with
open
(
self
.
helper
.
module_desc_path
(),
"wb"
)
as
f
:
f
.
write
(
module_pb
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录