Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a25545de
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
281
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
a25545de
编写于
12月 23, 2019
作者:
W
wuzewu
提交者:
GitHub
12月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Module V2 support (#274)
* Add module v2
上级
b2dc77ed
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
408 addition
and
431 deletion
+408
-431
paddlehub/__init__.py
paddlehub/__init__.py
+1
-1
paddlehub/commands/install.py
paddlehub/commands/install.py
+17
-7
paddlehub/commands/run.py
paddlehub/commands/run.py
+29
-26
paddlehub/commands/show.py
paddlehub/commands/show.py
+0
-2
paddlehub/module/check_info.proto
paddlehub/module/check_info.proto
+3
-2
paddlehub/module/check_info_pb2.py
paddlehub/module/check_info_pb2.py
+26
-11
paddlehub/module/checker.py
paddlehub/module/checker.py
+23
-13
paddlehub/module/manager.py
paddlehub/module/manager.py
+78
-42
paddlehub/module/module.py
paddlehub/module/module.py
+231
-327
未找到文件。
paddlehub/__init__.py
浏览文件 @
a25545de
...
@@ -38,7 +38,7 @@ from .common.logger import logger
...
@@ -38,7 +38,7 @@ from .common.logger import logger
from
.common.paddle_helper
import
connect_program
from
.common.paddle_helper
import
connect_program
from
.common.hub_server
import
default_hub_server
from
.common.hub_server
import
default_hub_server
from
.module.module
import
Module
,
create_module
from
.module.module
import
Module
from
.module.base_processor
import
BaseProcessor
from
.module.base_processor
import
BaseProcessor
from
.module.signature
import
Signature
,
create_signature
from
.module.signature
import
Signature
,
create_signature
from
.module.manager
import
default_module_manager
from
.module.manager
import
default_module_manager
...
...
paddlehub/commands/install.py
浏览文件 @
a25545de
...
@@ -18,6 +18,7 @@ from __future__ import division
...
@@ -18,6 +18,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
argparse
import
argparse
import
os
from
paddlehub.common
import
utils
from
paddlehub.common
import
utils
from
paddlehub.module.manager
import
default_module_manager
from
paddlehub.module.manager
import
default_module_manager
...
@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
...
@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print
(
"ERROR: Please specify a module name.
\n
"
)
print
(
"ERROR: Please specify a module name.
\n
"
)
self
.
help
()
self
.
help
()
return
False
return
False
extra
=
{
"command"
:
"install"
}
if
argv
[
0
].
endswith
(
"tar.gz"
)
or
argv
[
0
].
endswith
(
"phm"
):
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_package
=
argv
[
0
],
extra
=
extra
)
elif
os
.
path
.
exists
(
argv
[
0
])
and
os
.
path
.
isdir
(
argv
[
0
]):
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_dir
=
argv
[
0
],
extra
=
extra
)
else
:
module_name
=
argv
[
0
]
module_name
=
argv
[
0
]
module_version
=
None
if
"=="
not
in
module_name
else
module_name
.
split
(
module_version
=
None
if
"=="
not
in
module_name
else
module_name
.
split
(
"=="
)[
1
]
"=="
)[
1
]
module_name
=
module_name
if
"=="
not
in
module_name
else
module_name
.
split
(
module_name
=
module_name
if
"=="
not
in
module_name
else
module_name
.
split
(
"=="
)[
0
]
"=="
)[
0
]
extra
=
{
"command"
:
"install"
}
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
module_name
,
module_version
=
module_version
,
extra
=
extra
)
module_name
=
module_name
,
module_version
=
module_version
,
extra
=
extra
)
print
(
tips
)
print
(
tips
)
return
True
return
True
...
...
paddlehub/commands/run.py
浏览文件 @
a25545de
...
@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
...
@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
if
not
result
:
if
not
result
:
return
None
return
None
return
hub
.
Module
(
module_dir
=
module_dir
)
return
hub
.
Module
(
directory
=
module_dir
[
0
]
)
def
add_module_config_arg
(
self
):
def
add_module_config_arg
(
self
):
configs
=
self
.
module
.
processor
.
configs
()
configs
=
self
.
module
.
processor
.
configs
()
...
@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
...
@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
def
add_module_input_arg
(
self
):
def
add_module_input_arg
(
self
):
module_type
=
self
.
module
.
type
.
lower
()
module_type
=
self
.
module
.
type
.
lower
()
expect_data_format
=
self
.
module
.
processor
.
data_format
(
expect_data_format
=
self
.
module
.
processor
.
data_format
(
self
.
module
.
default_signature
.
name
)
self
.
module
.
default_signature
)
self
.
arg_input_group
.
add_argument
(
self
.
arg_input_group
.
add_argument
(
'--input_file'
,
'--input_file'
,
type
=
str
,
type
=
str
,
...
@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
...
@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
def
get_data
(
self
):
def
get_data
(
self
):
module_type
=
self
.
module
.
type
.
lower
()
module_type
=
self
.
module
.
type
.
lower
()
expect_data_format
=
self
.
module
.
processor
.
data_format
(
expect_data_format
=
self
.
module
.
processor
.
data_format
(
self
.
module
.
default_signature
.
name
)
self
.
module
.
default_signature
)
input_data
=
{}
input_data
=
{}
if
len
(
expect_data_format
)
==
1
:
if
len
(
expect_data_format
)
==
1
:
key
=
list
(
expect_data_format
.
keys
())[
0
]
key
=
list
(
expect_data_format
.
keys
())[
0
]
...
@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
...
@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
def
check_data
(
self
,
data
):
def
check_data
(
self
,
data
):
expect_data_format
=
self
.
module
.
processor
.
data_format
(
expect_data_format
=
self
.
module
.
processor
.
data_format
(
self
.
module
.
default_signature
.
name
)
self
.
module
.
default_signature
)
if
len
(
data
.
keys
())
!=
len
(
expect_data_format
.
keys
()):
if
len
(
data
.
keys
())
!=
len
(
expect_data_format
.
keys
()):
print
(
print
(
...
@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
...
@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
return
False
return
False
# If the module is not executable, give an alarm and exit
# If the module is not executable, give an alarm and exit
if
not
self
.
module
.
default_signatur
e
:
if
not
self
.
module
.
is_runabl
e
:
print
(
"ERROR! Module %s is not executable."
%
module_name
)
print
(
"ERROR! Module %s is not executable."
%
module_name
)
return
False
return
False
if
self
.
module
.
code_version
==
"v2"
:
results
=
self
.
module
(
argv
[
1
:])
else
:
self
.
module
.
check_processor
()
self
.
module
.
check_processor
()
self
.
add_module_config_arg
()
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
self
.
add_module_input_arg
()
...
@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
...
@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
return
False
return
False
results
=
self
.
module
(
results
=
self
.
module
(
sign_name
=
self
.
module
.
default_signature
.
nam
e
,
sign_name
=
self
.
module
.
default_signatur
e
,
data
=
data
,
data
=
data
,
use_gpu
=
self
.
args
.
use_gpu
,
use_gpu
=
self
.
args
.
use_gpu
,
batch_size
=
self
.
args
.
batch_size
,
batch_size
=
self
.
args
.
batch_size
,
...
...
paddlehub/commands/show.py
浏览文件 @
a25545de
...
@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
...
@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
cwd
=
os
.
getcwd
()
cwd
=
os
.
getcwd
()
module_dir
=
default_module_manager
.
search_module
(
module_name
)
module_dir
=
default_module_manager
.
search_module
(
module_name
)
module_dir
=
(
os
.
path
.
join
(
cwd
,
module_name
),
None
)
if
not
module_dir
else
module_dir
if
not
module_dir
or
not
os
.
path
.
exists
(
module_dir
[
0
]):
if
not
module_dir
or
not
os
.
path
.
exists
(
module_dir
[
0
]):
print
(
"%s is not existed!"
%
module_name
)
print
(
"%s is not existed!"
%
module_name
)
return
True
return
True
...
...
paddlehub/module/check_info.proto
浏览文件 @
a25545de
...
@@ -50,6 +50,7 @@ message CheckInfo {
...
@@ -50,6 +50,7 @@ message CheckInfo {
string
paddle_version
=
1
;
string
paddle_version
=
1
;
string
hub_version
=
2
;
string
hub_version
=
2
;
string
module_proto_version
=
3
;
string
module_proto_version
=
3
;
repeated
FileInfo
file_infos
=
4
;
string
module_code_version
=
4
;
repeated
Requires
requires
=
5
;
repeated
FileInfo
file_infos
=
5
;
repeated
Requires
requires
=
6
;
};
};
paddlehub/module/check_info_pb2.py
浏览文件 @
a25545de
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT!
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto
# source: check_info.proto
...
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
...
@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package
=
'paddlehub.module.checkinfo'
,
package
=
'paddlehub.module.checkinfo'
,
syntax
=
'proto3'
,
syntax
=
'proto3'
,
serialized_pb
=
_b
(
serialized_pb
=
_b
(
'
\n\x10\x63
heck_info.proto
\x12\x1a
paddlehub.module.checkinfo
\"\x85\x01\n\x08\x46
ileInfo
\x12\x11\n\t
file_name
\x18\x01
\x01
(
\t\x12\x33\n\x04
type
\x18\x02
\x01
(
\x0e\x32
%.paddlehub.module.checkinfo.FILE_TYPE
\x12\x0f\n\x07
is_need
\x18\x03
\x01
(
\x08\x12\x0b\n\x03
md5
\x18\x04
\x01
(
\t\x12\x13\n\x0b\x64\x65
scription
\x18\x05
\x01
(
\t\"\x84\x01\n\x08
Requires
\x12
>
\n\x0c
require_type
\x18\x01
\x01
(
\x0e\x32
(.paddlehub.module.checkinfo.REQUIRE_TYPE
\x12\x0f\n\x07
version
\x18\x02
\x01
(
\t\x12\x12\n\n
great_than
\x18\x03
\x01
(
\x08\x12\x13\n\x0b\x64\x65
scription
\x18\x04
\x01
(
\t\"\x
c8\x01\n\t
CheckInfo
\x12\x16\n\x0e
paddle_version
\x18\x01
\x01
(
\t\x12\x13\n\x0b
hub_version
\x18\x02
\x01
(
\t\x12\x1c\n\x14
module_proto_version
\x18\x03
\x01
(
\t\x12\x38\n\n
file_infos
\x18\x04
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.FileInfo
\x12\x36\n\x08
requires
\x18\x05
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.Requires*
\x1e\n\t
FILE_TYPE
\x12\x08\n\x04\x46
ILE
\x10\x00\x12\x07\n\x03\x44
IR
\x10\x01
*[
\n\x0c
REQUIRE_TYPE
\x12\x12\n\x0e
PYTHON_PACKAGE
\x10\x00\x12\x0e\n\n
HUB_MODULE
\x10\x01\x12\n\n\x06
SYSTEM
\x10\x02\x12\x0b\n\x07\x43
OMMAND
\x10\x03\x12\x0e\n\n
PY_VERSION
\x10\x04\x42\x02
H
\x03\x62\x06
proto3'
'
\n\x10\x63
heck_info.proto
\x12\x1a
paddlehub.module.checkinfo
\"\x85\x01\n\x08\x46
ileInfo
\x12\x11\n\t
file_name
\x18\x01
\x01
(
\t\x12\x33\n\x04
type
\x18\x02
\x01
(
\x0e\x32
%.paddlehub.module.checkinfo.FILE_TYPE
\x12\x0f\n\x07
is_need
\x18\x03
\x01
(
\x08\x12\x0b\n\x03
md5
\x18\x04
\x01
(
\t\x12\x13\n\x0b\x64\x65
scription
\x18\x05
\x01
(
\t\"\x84\x01\n\x08
Requires
\x12
>
\n\x0c
require_type
\x18\x01
\x01
(
\x0e\x32
(.paddlehub.module.checkinfo.REQUIRE_TYPE
\x12\x0f\n\x07
version
\x18\x02
\x01
(
\t\x12\x12\n\n
great_than
\x18\x03
\x01
(
\x08\x12\x13\n\x0b\x64\x65
scription
\x18\x04
\x01
(
\t\"\x
e5\x01\n\t
CheckInfo
\x12\x16\n\x0e
paddle_version
\x18\x01
\x01
(
\t\x12\x13\n\x0b
hub_version
\x18\x02
\x01
(
\t\x12\x1c\n\x14
module_proto_version
\x18\x03
\x01
(
\t\x12\x1b\n\x13
module_code_version
\x18\x04
\x01
(
\t\x12\x38\n\n
file_infos
\x18\x05
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.FileInfo
\x12\x36\n\x08
requires
\x18\x06
\x03
(
\x0b\x32
$.paddlehub.module.checkinfo.Requires*
\x1e\n\t
FILE_TYPE
\x12\x08\n\x04\x46
ILE
\x10\x00\x12\x07\n\x03\x44
IR
\x10\x01
*[
\n\x0c
REQUIRE_TYPE
\x12\x12\n\x0e
PYTHON_PACKAGE
\x10\x00\x12\x0e\n\n
HUB_MODULE
\x10\x01\x12\n\n\x06
SYSTEM
\x10\x02\x12\x0b\n\x07\x43
OMMAND
\x10\x03\x12\x0e\n\n
PY_VERSION
\x10\x04\x42\x02
H
\x03\x62\x06
proto3'
))
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
...
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
...
@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
],
containing_type
=
None
,
containing_type
=
None
,
options
=
None
,
options
=
None
,
serialized_start
=
5
22
,
serialized_start
=
5
51
,
serialized_end
=
5
52
,
serialized_end
=
5
81
,
)
)
_sym_db
.
RegisterEnumDescriptor
(
_FILE_TYPE
)
_sym_db
.
RegisterEnumDescriptor
(
_FILE_TYPE
)
...
@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
...
@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
],
],
containing_type
=
None
,
containing_type
=
None
,
options
=
None
,
options
=
None
,
serialized_start
=
5
54
,
serialized_start
=
5
83
,
serialized_end
=
6
45
,
serialized_end
=
6
74
,
)
)
_sym_db
.
RegisterEnumDescriptor
(
_REQUIRE_TYPE
)
_sym_db
.
RegisterEnumDescriptor
(
_REQUIRE_TYPE
)
...
@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
...
@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope
=
None
,
extension_scope
=
None
,
options
=
None
),
options
=
None
),
_descriptor
.
FieldDescriptor
(
_descriptor
.
FieldDescriptor
(
name
=
'
file_infos
'
,
name
=
'
module_code_version
'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.
file_infos
'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.
module_code_version
'
,
index
=
3
,
index
=
3
,
number
=
4
,
number
=
4
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'file_infos'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.file_infos'
,
index
=
4
,
number
=
5
,
type
=
11
,
type
=
11
,
cpp_type
=
10
,
cpp_type
=
10
,
label
=
3
,
label
=
3
,
...
@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
...
@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor
.
FieldDescriptor
(
_descriptor
.
FieldDescriptor
(
name
=
'requires'
,
name
=
'requires'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.requires'
,
full_name
=
'paddlehub.module.checkinfo.CheckInfo.requires'
,
index
=
4
,
index
=
5
,
number
=
5
,
number
=
6
,
type
=
11
,
type
=
11
,
cpp_type
=
10
,
cpp_type
=
10
,
label
=
3
,
label
=
3
,
...
@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
...
@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges
=
[],
extension_ranges
=
[],
oneofs
=
[],
oneofs
=
[],
serialized_start
=
320
,
serialized_start
=
320
,
serialized_end
=
5
20
,
serialized_end
=
5
49
,
)
)
_FILEINFO
.
fields_by_name
[
'type'
].
enum_type
=
_FILE_TYPE
_FILEINFO
.
fields_by_name
[
'type'
].
enum_type
=
_FILE_TYPE
...
...
paddlehub/module/checker.py
浏览文件 @
a25545de
...
@@ -32,20 +32,22 @@ FILE_SEP = "/"
...
@@ -32,20 +32,22 @@ FILE_SEP = "/"
class
ModuleChecker
(
object
):
class
ModuleChecker
(
object
):
def
__init__
(
self
,
module_path
):
def
__init__
(
self
,
directory
):
self
.
module_path
=
module_path
self
.
_directory
=
directory
self
.
_pb_path
=
os
.
path
.
join
(
self
.
directory
,
CHECK_INFO_PB_FILENAME
)
def
generate_check_info
(
self
):
def
generate_check_info
(
self
):
check_info
=
check_info_pb2
.
CheckInfo
()
check_info
=
check_info_pb2
.
CheckInfo
()
check_info
.
paddle_version
=
paddle
.
__version__
check_info
.
paddle_version
=
paddle
.
__version__
check_info
.
hub_version
=
hub_version
check_info
.
hub_version
=
hub_version
check_info
.
module_proto_version
=
module_proto_version
check_info
.
module_proto_version
=
module_proto_version
check_info
.
module_code_version
=
"v2"
file_infos
=
check_info
.
file_infos
file_infos
=
check_info
.
file_infos
file_list
=
[
file
for
file
in
os
.
listdir
(
self
.
module_path
)]
file_list
=
[
file
for
file
in
os
.
listdir
(
self
.
directory
)]
while
file_list
:
while
file_list
:
file
=
file_list
[
0
]
file
=
file_list
[
0
]
file_list
=
file_list
[
1
:]
file_list
=
file_list
[
1
:]
abs_path
=
os
.
path
.
join
(
self
.
module_path
,
file
)
abs_path
=
os
.
path
.
join
(
self
.
directory
,
file
)
if
os
.
path
.
isdir
(
abs_path
):
if
os
.
path
.
isdir
(
abs_path
):
for
sub_file
in
os
.
listdir
(
abs_path
):
for
sub_file
in
os
.
listdir
(
abs_path
):
sub_file
=
os
.
path
.
join
(
file
,
sub_file
)
sub_file
=
os
.
path
.
join
(
file
,
sub_file
)
...
@@ -62,9 +64,12 @@ class ModuleChecker(object):
...
@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info
.
type
=
check_info_pb2
.
FILE
file_info
.
type
=
check_info_pb2
.
FILE
file_info
.
is_need
=
True
file_info
.
is_need
=
True
with
open
(
os
.
path
.
join
(
self
.
module_path
,
CHECK_INFO_PB_FILENAME
),
with
open
(
self
.
pb_path
,
"wb"
)
as
file
:
"wb"
)
as
fi
:
file
.
write
(
check_info
.
SerializeToString
())
fi
.
write
(
check_info
.
SerializeToString
())
@
property
def
module_code_version
(
self
):
return
self
.
check_info
.
module_code_version
@
property
@
property
def
module_proto_version
(
self
):
def
module_proto_version
(
self
):
...
@@ -82,20 +87,25 @@ class ModuleChecker(object):
...
@@ -82,20 +87,25 @@ class ModuleChecker(object):
def
file_infos
(
self
):
def
file_infos
(
self
):
return
self
.
check_info
.
file_infos
return
self
.
check_info
.
file_infos
@
property
def
directory
(
self
):
return
self
.
_directory
@
property
def
pb_path
(
self
):
return
self
.
_pb_path
def
check
(
self
):
def
check
(
self
):
result
=
True
result
=
True
self
.
check_info_pb_path
=
os
.
path
.
join
(
self
.
module_path
,
CHECK_INFO_PB_FILENAME
)
if
not
(
os
.
path
.
exists
(
self
.
check_info_pb_path
)
if
not
(
os
.
path
.
exists
(
self
.
pb_path
)
or
os
.
path
.
isfile
(
self
.
pb_path
)):
or
os
.
path
.
isfile
(
self
.
check_info_pb_path
)):
logger
.
warning
(
logger
.
warning
(
"This module lacks core file %s"
%
CHECK_INFO_PB_FILENAME
)
"This module lacks core file %s"
%
CHECK_INFO_PB_FILENAME
)
result
=
False
result
=
False
self
.
check_info
=
check_info_pb2
.
CheckInfo
()
self
.
check_info
=
check_info_pb2
.
CheckInfo
()
try
:
try
:
with
open
(
self
.
check_info_
pb_path
,
"rb"
)
as
fi
:
with
open
(
self
.
pb_path
,
"rb"
)
as
fi
:
pb_string
=
fi
.
read
()
pb_string
=
fi
.
read
()
result
=
self
.
check_info
.
ParseFromString
(
pb_string
)
result
=
self
.
check_info
.
ParseFromString
(
pb_string
)
if
len
(
pb_string
)
==
0
or
(
result
is
not
None
if
len
(
pb_string
)
==
0
or
(
result
is
not
None
...
@@ -182,7 +192,7 @@ class ModuleChecker(object):
...
@@ -182,7 +192,7 @@ class ModuleChecker(object):
for
file_info
in
self
.
file_infos
:
for
file_info
in
self
.
file_infos
:
file_type
=
file_info
.
type
file_type
=
file_info
.
type
file_path
=
file_info
.
file_name
.
replace
(
FILE_SEP
,
os
.
sep
)
file_path
=
file_info
.
file_name
.
replace
(
FILE_SEP
,
os
.
sep
)
file_path
=
os
.
path
.
join
(
self
.
module_path
,
file_path
)
file_path
=
os
.
path
.
join
(
self
.
directory
,
file_path
)
if
not
os
.
path
.
exists
(
file_path
):
if
not
os
.
path
.
exists
(
file_path
):
if
file_info
.
is_need
:
if
file_info
.
is_need
:
logger
.
warning
(
logger
.
warning
(
...
...
paddlehub/module/manager.py
浏览文件 @
a25545de
...
@@ -19,6 +19,7 @@ from __future__ import print_function
...
@@ -19,6 +19,7 @@ from __future__ import print_function
import
os
import
os
import
shutil
import
shutil
import
tarfile
from
paddlehub.common
import
utils
from
paddlehub.common
import
utils
from
paddlehub.common
import
srv_utils
from
paddlehub.common
import
srv_utils
...
@@ -77,10 +78,15 @@ class LocalModuleManager(object):
...
@@ -77,10 +78,15 @@ class LocalModuleManager(object):
return
self
.
modules_dict
.
get
(
module_name
,
None
)
return
self
.
modules_dict
.
get
(
module_name
,
None
)
def
install_module
(
self
,
def
install_module
(
self
,
module_name
,
module_name
=
None
,
module_dir
=
None
,
module_package
=
None
,
module_version
=
None
,
module_version
=
None
,
upgrade
=
False
,
upgrade
=
False
,
extra
=
None
):
extra
=
None
):
md5_value
=
installed_module_version
=
None
from_user_dir
=
True
if
module_dir
else
False
if
module_name
:
self
.
all_modules
(
update
=
True
)
self
.
all_modules
(
update
=
True
)
module_info
=
self
.
modules_dict
.
get
(
module_name
,
None
)
module_info
=
self
.
modules_dict
.
get
(
module_name
,
None
)
if
module_info
:
if
module_info
:
...
@@ -99,8 +105,9 @@ class LocalModuleManager(object):
...
@@ -99,8 +105,9 @@ class LocalModuleManager(object):
url
=
search_result
.
get
(
'url'
,
None
)
url
=
search_result
.
get
(
'url'
,
None
)
md5_value
=
search_result
.
get
(
'md5'
,
None
)
md5_value
=
search_result
.
get
(
'md5'
,
None
)
installed_module_version
=
search_result
.
get
(
'version'
,
None
)
installed_module_version
=
search_result
.
get
(
'version'
,
None
)
if
not
url
or
(
module_version
is
not
None
and
installed_module_version
if
not
url
or
(
module_version
is
not
None
!=
module_version
)
or
(
name
!=
module_name
):
and
installed_module_version
!=
module_version
)
or
(
name
!=
module_name
):
if
default_hub_server
.
_server_check
()
is
False
:
if
default_hub_server
.
_server_check
()
is
False
:
tips
=
"Request Hub-Server unsuccessfully, please check your network."
tips
=
"Request Hub-Server unsuccessfully, please check your network."
else
:
else
:
...
@@ -123,13 +130,42 @@ class LocalModuleManager(object):
...
@@ -123,13 +130,42 @@ class LocalModuleManager(object):
delete_file
=
True
,
delete_file
=
True
,
print_progress
=
True
)
print_progress
=
True
)
if
module_package
:
with
tarfile
.
open
(
module_package
,
"r:gz"
)
as
tar
:
file_names
=
tar
.
getnames
()
size
=
len
(
file_names
)
-
1
module_dir
=
os
.
path
.
split
(
file_names
[
0
])[
0
]
module_dir
=
os
.
path
.
join
(
hub
.
CACHE_HOME
,
module_dir
)
# remove cache
if
os
.
path
.
exists
(
module_dir
):
shutil
.
rmtree
(
module_dir
)
for
index
,
file_name
in
enumerate
(
file_names
):
tar
.
extract
(
file_name
,
hub
.
CACHE_HOME
)
if
module_dir
:
if
module_dir
:
with
open
(
os
.
path
.
join
(
MODULE_HOME
,
module_dir
,
"md5.txt"
),
if
not
module_name
:
module_name
=
hub
.
Module
(
directory
=
module_dir
).
name
self
.
all_modules
(
update
=
False
)
module_info
=
self
.
modules_dict
.
get
(
module_name
,
None
)
if
module_info
:
module_dir
=
self
.
modules_dict
[
module_name
][
0
]
module_tag
=
module_name
if
not
module_version
else
'%s-%s'
%
(
module_name
,
module_version
)
tips
=
"Module %s already installed in %s"
%
(
module_tag
,
module_dir
)
return
True
,
tips
,
self
.
modules_dict
[
module_name
]
if
md5_value
:
with
open
(
os
.
path
.
join
(
MODULE_HOME
,
module_dir
,
"md5.txt"
),
"w"
)
as
fp
:
"w"
)
as
fp
:
fp
.
write
(
md5_value
)
fp
.
write
(
md5_value
)
save_path
=
os
.
path
.
join
(
MODULE_HOME
,
module_name
)
save_path
=
os
.
path
.
join
(
MODULE_HOME
,
module_name
)
if
os
.
path
.
exists
(
save_path
):
if
os
.
path
.
exists
(
save_path
):
shutil
.
rmtree
(
save_path
)
shutil
.
move
(
save_path
)
if
from_user_dir
:
shutil
.
copytree
(
module_dir
,
save_path
)
else
:
shutil
.
move
(
module_dir
,
save_path
)
shutil
.
move
(
module_dir
,
save_path
)
module_dir
=
save_path
module_dir
=
save_path
tips
=
"Successfully installed %s"
%
module_name
tips
=
"Successfully installed %s"
%
module_name
...
...
paddlehub/module/module.py
浏览文件 @
a25545de
...
@@ -21,6 +21,10 @@ import os
...
@@ -21,6 +21,10 @@ import os
import
time
import
time
import
sys
import
sys
import
functools
import
functools
import
inspect
import
importlib
import
tarfile
from
collections
import
defaultdict
from
shutil
import
copyfile
from
shutil
import
copyfile
import
paddle
import
paddle
...
@@ -28,22 +32,19 @@ import paddle.fluid as fluid
...
@@ -28,22 +32,19 @@ import paddle.fluid as fluid
from
paddlehub.common
import
utils
from
paddlehub.common
import
utils
from
paddlehub.common
import
paddle_helper
from
paddlehub.common
import
paddle_helper
from
paddlehub.common.
logger
import
logger
from
paddlehub.common.
dir
import
CACHE_HOME
from
paddlehub.common.lock
import
lock
from
paddlehub.common.lock
import
lock
from
paddlehub.common.downloader
import
default_downloader
from
paddlehub.common.logger
import
logger
from
paddlehub.common.hub_server
import
CacheUpdater
from
paddlehub.module
import
module_desc_pb2
from
paddlehub.module
import
module_desc_pb2
from
paddlehub.common.dir
import
CONF_HOME
from
paddlehub.module
import
check_info_pb2
from
paddlehub.module
import
check_info_pb2
from
paddlehub.common.hub_server
import
CacheUpdater
from
paddlehub.module.signature
import
Signature
,
create_signature
from
paddlehub.module.checker
import
ModuleChecker
from
paddlehub.module.manager
import
default_module_manager
from
paddlehub.module.manager
import
default_module_manager
from
paddlehub.module.checker
import
ModuleChecker
from
paddlehub.module.signature
import
Signature
,
create_signature
from
paddlehub.module.base_processor
import
BaseProcessor
from
paddlehub.module.base_processor
import
BaseProcessor
from
paddlehub.io.parser
import
yaml_parser
from
paddlehub.io.parser
import
yaml_parser
from
paddlehub
import
version
from
paddlehub
import
version
__all__
=
[
'Module'
,
'create_module'
]
# PaddleHub module dir name
# PaddleHub module dir name
ASSETS_DIRNAME
=
"assets"
ASSETS_DIRNAME
=
"assets"
MODEL_DIRNAME
=
"model"
MODEL_DIRNAME
=
"model"
...
@@ -52,67 +53,226 @@ PYTHON_DIR = "python"
...
@@ -52,67 +53,226 @@ PYTHON_DIR = "python"
PROCESSOR_NAME
=
"processor"
PROCESSOR_NAME
=
"processor"
# PaddleHub var prefix
# PaddleHub var prefix
HUB_VAR_PREFIX
=
"@HUB_%s@"
HUB_VAR_PREFIX
=
"@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX
=
"phm"
def
create_module
(
directory
,
name
,
author
,
email
,
module_type
,
summary
,
version
):
save_file_name
=
"{}-{}.{}"
.
format
(
name
,
version
,
HUB_PACKAGE_SUFFIX
)
# record module info and serialize
desc
=
module_desc_pb2
.
ModuleDesc
()
attr
=
desc
.
attr
attr
.
type
=
module_desc_pb2
.
MAP
module_info
=
attr
.
map
.
data
[
'module_info'
]
module_info
.
type
=
module_desc_pb2
.
MAP
utils
.
from_pyobj_to_module_attr
(
name
,
module_info
.
map
.
data
[
'name'
])
utils
.
from_pyobj_to_module_attr
(
author
,
module_info
.
map
.
data
[
'author'
])
utils
.
from_pyobj_to_module_attr
(
email
,
module_info
.
map
.
data
[
'author_email'
])
utils
.
from_pyobj_to_module_attr
(
module_type
,
module_info
.
map
.
data
[
'type'
])
utils
.
from_pyobj_to_module_attr
(
summary
,
module_info
.
map
.
data
[
'summary'
])
utils
.
from_pyobj_to_module_attr
(
version
,
module_info
.
map
.
data
[
'version'
])
module_desc_path
=
os
.
path
.
join
(
directory
,
"module_desc.pb"
)
with
open
(
module_desc_path
,
"wb"
)
as
f
:
f
.
write
(
desc
.
SerializeToString
())
# generate check info
checker
=
ModuleChecker
(
directory
)
checker
.
generate_check_info
()
# add __init__
module_init_1
=
os
.
path
.
join
(
directory
,
"__init__.py"
)
with
open
(
module_init_1
,
"a"
)
as
file
:
file
.
write
(
""
)
module_init_2
=
os
.
path
.
join
(
directory
,
"python"
,
"__init__.py"
)
with
open
(
module_init_2
,
"a"
)
as
file
:
file
.
write
(
""
)
# package the module
with
tarfile
.
open
(
save_file_name
,
"w:gz"
)
as
tar
:
for
dirname
,
_
,
files
in
os
.
walk
(
directory
):
for
file
in
files
:
tar
.
add
(
os
.
path
.
join
(
dirname
,
file
))
os
.
remove
(
module_desc_path
)
os
.
remove
(
checker
.
pb_path
)
os
.
remove
(
module_init_1
)
os
.
remove
(
module_init_2
)
class
Module
(
object
):
def
__new__
(
cls
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
version
=
None
):
module
=
None
if
cls
.
__name__
==
"Module"
:
if
name
:
module
=
cls
.
init_with_name
(
name
=
name
,
version
=
version
)
elif
directory
:
module
=
cls
.
init_with_directory
(
directory
=
directory
)
elif
module_dir
:
logger
.
warning
(
"Parameter module_dir is deprecated, please use directory to specify the path"
)
if
isinstance
(
module_dir
,
list
)
or
isinstance
(
module_dir
,
tuple
):
directory
=
module_dir
[
0
]
version
=
module_dir
[
1
]
else
:
directory
=
module_dir
module
=
cls
.
init_with_directory
(
directory
=
directory
)
if
not
module
:
module
=
object
.
__new__
(
cls
)
CacheUpdater
(
name
,
version
).
start
()
return
module
def
__init__
(
self
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
version
=
None
):
if
not
directory
:
return
self
.
_code_version
=
"v2"
self
.
_directory
=
directory
self
.
module_desc_path
=
os
.
path
.
join
(
self
.
directory
,
MODULE_DESC_PBNAME
)
self
.
_desc
=
module_desc_pb2
.
ModuleDesc
()
with
open
(
self
.
module_desc_path
,
"rb"
)
as
file
:
self
.
_desc
.
ParseFromString
(
file
.
read
())
module_info
=
self
.
desc
.
attr
.
map
.
data
[
'module_info'
]
self
.
_name
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'name'
])
self
.
_author
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author'
])
self
.
_author_email
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author_email'
])
self
.
_version
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'version'
])
self
.
_type
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'type'
])
self
.
_summary
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'summary'
])
self
.
_initialize
()
@
classmethod
def
init_with_name
(
cls
,
name
,
version
=
None
):
fp_lock
=
open
(
os
.
path
.
join
(
CACHE_HOME
,
name
),
"a"
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_EX
)
log_msg
=
"Installing %s module"
%
name
if
version
:
log_msg
+=
"-%s"
%
version
logger
.
info
(
log_msg
)
extra
=
{
"command"
:
"install"
}
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
name
,
module_version
=
version
,
extra
=
extra
)
if
not
result
:
logger
.
error
(
tips
)
raise
RuntimeError
(
tips
)
logger
.
info
(
tips
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
return
cls
.
init_with_directory
(
directory
=
module_dir
[
0
])
@
classmethod
def
init_with_directory
(
cls
,
directory
):
desc_file
=
os
.
path
.
join
(
directory
,
MODULE_DESC_PBNAME
)
checker
=
ModuleChecker
(
directory
)
checker
.
check
()
def
create_module
(
sign_arr
,
module_code_version
=
checker
.
module_code_version
module_dir
,
if
module_code_version
==
"v2"
:
processor
=
None
,
basename
=
os
.
path
.
split
(
directory
)[
-
1
]
assets
=
None
,
dirname
=
os
.
path
.
join
(
*
list
(
os
.
path
.
split
(
directory
)[:
-
1
]))
module_info
=
None
,
sys
.
path
.
append
(
dirname
)
exe
=
None
,
pymodule
=
importlib
.
import_module
(
extra_info
=
None
):
"{}.python.module"
.
format
(
basename
))
sign_arr
=
utils
.
to_list
(
sign_arr
)
return
pymodule
.
HubModule
(
directory
=
directory
)
module
=
Module
(
return
ModuleV1
(
directory
=
directory
)
signatures
=
sign_arr
,
processor
=
processor
,
@
property
assets
=
assets
,
def
desc
(
self
):
module_info
=
module_info
,
return
self
.
_desc
extra_info
=
extra_info
)
module
.
serialize_to_path
(
path
=
module_dir
,
exe
=
exe
)
@
property
def
directory
(
self
):
return
self
.
_directory
@
property
def
author
(
self
):
return
self
.
_author
@
property
def
author_email
(
self
):
return
self
.
_author_email
@
property
def
summary
(
self
):
return
self
.
_summary
@
property
def
type
(
self
):
return
self
.
_type
@
property
def
version
(
self
):
return
self
.
_version
@
property
def
name
(
self
):
return
self
.
_name
@
property
def
name_prefix
(
self
):
return
self
.
_name_prefix
@
property
def
code_version
(
self
):
return
self
.
_code_version
@
property
def
is_runable
(
self
):
return
False
def
_initialize
(
self
):
pass
class
ModuleHelper
(
object
):
class
ModuleHelper
(
object
):
def
__init__
(
self
,
module_dir
):
def
__init__
(
self
,
directory
):
self
.
module_dir
=
module_dir
self
.
directory
=
directory
def
module_desc_path
(
self
):
def
module_desc_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
MODULE_DESC_PBNAME
)
return
os
.
path
.
join
(
self
.
directory
,
MODULE_DESC_PBNAME
)
def
model_path
(
self
):
def
model_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
MODEL_DIRNAME
)
return
os
.
path
.
join
(
self
.
directory
,
MODEL_DIRNAME
)
def
processor_path
(
self
):
def
processor_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
PYTHON_DIR
)
return
os
.
path
.
join
(
self
.
directory
,
PYTHON_DIR
)
def
processor_name
(
self
):
def
processor_name
(
self
):
return
PROCESSOR_NAME
return
PROCESSOR_NAME
def
assets_path
(
self
):
def
assets_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
ASSETS_DIRNAME
)
return
os
.
path
.
join
(
self
.
directory
,
ASSETS_DIRNAME
)
class
Module
(
object
):
class
ModuleV1
(
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
name
=
None
,
module_dir
=
None
,
signatures
=
None
,
module_info
=
None
,
assets
=
None
,
processor
=
None
,
extra_info
=
None
,
version
=
None
):
version
=
None
):
self
.
desc
=
module_desc_pb2
.
ModuleDesc
()
if
not
directory
:
return
super
(
ModuleV1
,
self
).
__init__
(
name
,
directory
,
module_dir
,
version
)
self
.
_code_version
=
"v1"
self
.
program
=
None
self
.
program
=
None
self
.
assets
=
[]
self
.
assets
=
[]
self
.
helper
=
None
self
.
helper
=
None
self
.
signatures
=
{}
self
.
signatures
=
{}
self
.
default_signature
=
None
self
.
default_signature
=
None
self
.
module_info
=
None
self
.
processor
=
None
self
.
processor
=
None
self
.
extra_info
=
{}
if
extra_info
is
None
else
extra_info
self
.
extra_info
=
{}
if
not
isinstance
(
self
.
extra_info
,
dict
):
raise
TypeError
(
"The extra_info should be an instance of python dict"
)
# cache data
# cache data
self
.
last_call_name
=
None
self
.
last_call_name
=
None
...
@@ -120,62 +280,21 @@ class Module(object):
...
@@ -120,62 +280,21 @@ class Module(object):
self
.
cache_fetch_dict
=
None
self
.
cache_fetch_dict
=
None
self
.
cache_program
=
None
self
.
cache_program
=
None
fp_lock
=
open
(
os
.
path
.
join
(
CONF_HOME
,
'config.json'
))
self
.
helper
=
ModuleHelper
(
directory
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_EX
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
if
name
:
self
.
program
,
_
,
_
=
fluid
.
io
.
load_inference_model
(
self
.
_init_with_name
(
name
=
name
,
version
=
version
)
self
.
helper
.
model_path
(),
executor
=
exe
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
for
block
in
self
.
program
.
blocks
:
elif
module_dir
:
for
op
in
block
.
ops
:
self
.
_init_with_module_file
(
module_dir
=
module_dir
[
0
])
if
"op_callstack"
in
op
.
all_attrs
():
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
op
.
_set_attr
(
"op_callstack"
,
[
""
])
name
=
module_dir
[
0
].
split
(
"/"
)[
-
1
]
self
.
_load_processor
()
if
len
(
module_dir
)
>
1
:
self
.
_load_assets
()
version
=
module_dir
[
1
]
self
.
_recover_from_desc
()
else
:
self
.
_generate_sign_attr
()
version
=
default_module_manager
.
search_module
(
name
)[
1
]
self
.
_generate_extra_info
()
elif
signatures
:
self
.
_restore_parameter
(
self
.
program
)
if
processor
:
self
.
_recover_variable_info
(
self
.
program
)
if
not
issubclass
(
processor
,
BaseProcessor
):
raise
TypeError
(
"Processor shoule be an instance of paddlehub.BaseProcessor"
)
if
assets
:
self
.
assets
=
utils
.
to_list
(
assets
)
# for asset in assets:
# utils.check_path(assets)
self
.
processor
=
processor
self
.
_generate_module_info
(
module_info
)
self
.
_init_with_signature
(
signatures
=
signatures
)
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
else
:
lock
.
flock
(
fp_lock
,
lock
.
LOCK_UN
)
raise
ValueError
(
"Module initialized parameter is empty"
)
CacheUpdater
(
name
,
version
).
start
()
def
_init_with_name
(
self
,
name
,
version
=
None
):
log_msg
=
"Installing %s module"
%
name
if
version
:
log_msg
+=
"-%s"
%
version
logger
.
info
(
log_msg
)
extra
=
{
"command"
:
"install"
}
result
,
tips
,
module_dir
=
default_module_manager
.
install_module
(
module_name
=
name
,
module_version
=
version
,
extra
=
extra
)
if
not
result
:
logger
.
error
(
tips
)
raise
RuntimeError
(
tips
)
else
:
logger
.
info
(
tips
)
self
.
_init_with_module_file
(
module_dir
[
0
])
def
_init_with_url
(
self
,
url
):
utils
.
check_url
(
url
)
result
,
tips
,
module_dir
=
default_downloader
.
download_file_and_uncompress
(
url
,
save_path
=
"."
)
if
not
result
:
logger
.
error
(
tips
)
raise
RuntimeError
(
tips
)
else
:
self
.
_init_with_module_file
(
module_dir
)
def
_dump_processor
(
self
):
def
_dump_processor
(
self
):
import
inspect
import
inspect
...
@@ -216,52 +335,6 @@ class Module(object):
...
@@ -216,52 +335,6 @@ class Module(object):
filepath
=
os
.
path
.
join
(
self
.
helper
.
assets_path
(),
file
)
filepath
=
os
.
path
.
join
(
self
.
helper
.
assets_path
(),
file
)
self
.
assets
.
append
(
filepath
)
self
.
assets
.
append
(
filepath
)
def
_init_with_module_file
(
self
,
module_dir
):
checker
=
ModuleChecker
(
module_dir
)
checker
.
check
()
self
.
helper
=
ModuleHelper
(
module_dir
)
with
open
(
self
.
helper
.
module_desc_path
(),
"rb"
)
as
fi
:
self
.
desc
.
ParseFromString
(
fi
.
read
())
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
self
.
program
,
_
,
_
=
fluid
.
io
.
load_inference_model
(
self
.
helper
.
model_path
(),
executor
=
exe
)
for
block
in
self
.
program
.
blocks
:
for
op
in
block
.
ops
:
if
"op_callstack"
in
op
.
all_attrs
():
op
.
_set_attr
(
"op_callstack"
,
[
""
])
self
.
_load_processor
()
self
.
_load_assets
()
self
.
_recover_from_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
self
.
_restore_parameter
(
self
.
program
)
self
.
_recover_variable_info
(
self
.
program
)
def
_init_with_signature
(
self
,
signatures
):
self
.
name_prefix
=
HUB_VAR_PREFIX
%
self
.
name
self
.
_process_signatures
(
signatures
)
self
.
_check_signatures
()
self
.
_generate_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
def
_init_with_program
(
self
,
program
):
pass
def
_process_signatures
(
self
,
signatures
):
self
.
signatures
=
{}
self
.
program
=
signatures
[
0
].
inputs
[
0
].
block
.
program
for
sign
in
signatures
:
if
sign
.
name
in
self
.
signatures
:
raise
ValueError
(
"Error! Signature array contains duplicated signatrues %s"
%
sign
)
if
self
.
default_signature
is
None
and
sign
.
for_predict
:
self
.
default_signature
=
sign
self
.
signatures
[
sign
.
name
]
=
sign
def
_restore_parameter
(
self
,
program
):
def
_restore_parameter
(
self
,
program
):
global_block
=
program
.
global_block
()
global_block
=
program
.
global_block
()
param_attrs
=
self
.
desc
.
attr
.
map
.
data
[
'param_attrs'
]
param_attrs
=
self
.
desc
.
attr
.
map
.
data
[
'param_attrs'
]
...
@@ -302,21 +375,6 @@ class Module(object):
...
@@ -302,21 +375,6 @@ class Module(object):
self
.
__dict__
[
"get_%s"
%
key
]
=
functools
.
partial
(
self
.
__dict__
[
"get_%s"
%
key
]
=
functools
.
partial
(
self
.
get_extra_info
,
key
=
key
)
self
.
get_extra_info
,
key
=
key
)
def
_generate_module_info
(
self
,
module_info
=
None
):
if
not
module_info
:
self
.
module_info
=
{}
else
:
if
not
utils
.
is_yaml_file
(
module_info
):
logger
.
critical
(
"Module info file should be yaml format"
)
exit
(
1
)
self
.
module_info
=
yaml_parser
.
parse
(
module_info
)
self
.
author
=
self
.
module_info
.
get
(
'author'
,
'UNKNOWN'
)
self
.
author_email
=
self
.
module_info
.
get
(
'author_email'
,
'UNKNOWN'
)
self
.
summary
=
self
.
module_info
.
get
(
'summary'
,
'UNKNOWN'
)
self
.
type
=
self
.
module_info
.
get
(
'type'
,
'UNKNOWN'
)
self
.
version
=
self
.
module_info
.
get
(
'version'
,
'UNKNOWN'
)
self
.
name
=
self
.
module_info
.
get
(
'name'
,
'UNKNOWN'
)
def
_generate_sign_attr
(
self
):
def
_generate_sign_attr
(
self
):
self
.
_check_signatures
()
self
.
_check_signatures
()
for
sign
in
self
.
signatures
:
for
sign
in
self
.
signatures
:
...
@@ -369,21 +427,21 @@ class Module(object):
...
@@ -369,21 +427,21 @@ class Module(object):
default_signature_name
=
utils
.
from_module_attr_to_pyobj
(
default_signature_name
=
utils
.
from_module_attr_to_pyobj
(
self
.
desc
.
attr
.
map
.
data
[
'default_signature'
])
self
.
desc
.
attr
.
map
.
data
[
'default_signature'
])
self
.
default_signature
=
self
.
signatures
[
self
.
default_signature
=
self
.
signatures
[
default_signature_name
]
if
default_signature_name
else
None
default_signature_name
]
.
name
if
default_signature_name
else
None
# recover module info
# recover module info
module_info
=
self
.
desc
.
attr
.
map
.
data
[
'module_info'
]
module_info
=
self
.
desc
.
attr
.
map
.
data
[
'module_info'
]
self
.
name
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
name
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'name'
])
module_info
.
map
.
data
[
'name'
])
self
.
author
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
author
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author'
])
module_info
.
map
.
data
[
'author'
])
self
.
author_email
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
author_email
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'author_email'
])
module_info
.
map
.
data
[
'author_email'
])
self
.
version
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
version
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'version'
])
module_info
.
map
.
data
[
'version'
])
self
.
type
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
type
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'type'
])
module_info
.
map
.
data
[
'type'
])
self
.
summary
=
utils
.
from_module_attr_to_pyobj
(
self
.
_
summary
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'summary'
])
module_info
.
map
.
data
[
'summary'
])
# recover extra info
# recover extra info
...
@@ -393,77 +451,9 @@ class Module(object):
...
@@ -393,77 +451,9 @@ class Module(object):
self
.
extra_info
[
key
]
=
utils
.
from_module_attr_to_pyobj
(
value
)
self
.
extra_info
[
key
]
=
utils
.
from_module_attr_to_pyobj
(
value
)
# recover name prefix
# recover name prefix
self
.
name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
_name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
def
_generate_desc
(
self
):
# save fluid Parameter
attr
=
self
.
desc
.
attr
attr
.
type
=
module_desc_pb2
.
MAP
param_attrs
=
attr
.
map
.
data
[
'param_attrs'
]
param_attrs
.
type
=
module_desc_pb2
.
MAP
for
param
in
self
.
program
.
global_block
().
iter_parameters
():
param_attr
=
param_attrs
.
map
.
data
[
param
.
name
]
paddle_helper
.
from_param_to_module_attr
(
param
,
param_attr
)
# save Variable Info
var_infos
=
attr
.
map
.
data
[
'var_infos'
]
var_infos
.
type
=
module_desc_pb2
.
MAP
for
block
in
self
.
program
.
blocks
:
for
var
in
block
.
vars
.
values
():
var_info
=
var_infos
.
map
.
data
[
var
.
name
]
var_info
.
type
=
module_desc_pb2
.
MAP
utils
.
from_pyobj_to_module_attr
(
var
.
stop_gradient
,
var_info
.
map
.
data
[
'stop_gradient'
])
utils
.
from_pyobj_to_module_attr
(
block
.
idx
,
var_info
.
map
.
data
[
'block_id'
])
# save signarture info
for
key
,
sign
in
self
.
signatures
.
items
():
var
=
self
.
desc
.
sign2var
[
sign
.
name
]
feed_desc
=
var
.
feed_desc
fetch_desc
=
var
.
fetch_desc
feed_names
=
sign
.
feed_names
fetch_names
=
sign
.
fetch_names
for
index
,
input
in
enumerate
(
sign
.
inputs
):
feed_var
=
feed_desc
.
add
()
feed_var
.
var_name
=
self
.
get_var_name_with_prefix
(
input
.
name
)
feed_var
.
alias
=
feed_names
[
index
]
for
index
,
output
in
enumerate
(
sign
.
outputs
):
fetch_var
=
fetch_desc
.
add
()
fetch_var
.
var_name
=
self
.
get_var_name_with_prefix
(
output
.
name
)
fetch_var
.
alias
=
fetch_names
[
index
]
# save default signature
utils
.
from_pyobj_to_module_attr
(
self
.
default_signature
.
name
if
self
.
default_signature
else
None
,
attr
.
map
.
data
[
'default_signature'
])
# save name prefix
utils
.
from_pyobj_to_module_attr
(
self
.
name_prefix
,
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
# save module info
module_info
=
attr
.
map
.
data
[
'module_info'
]
module_info
.
type
=
module_desc_pb2
.
MAP
utils
.
from_pyobj_to_module_attr
(
self
.
name
,
module_info
.
map
.
data
[
'name'
])
utils
.
from_pyobj_to_module_attr
(
self
.
version
,
module_info
.
map
.
data
[
'version'
])
utils
.
from_pyobj_to_module_attr
(
self
.
author
,
module_info
.
map
.
data
[
'author'
])
utils
.
from_pyobj_to_module_attr
(
self
.
author_email
,
module_info
.
map
.
data
[
'author_email'
])
utils
.
from_pyobj_to_module_attr
(
self
.
type
,
module_info
.
map
.
data
[
'type'
])
utils
.
from_pyobj_to_module_attr
(
self
.
summary
,
module_info
.
map
.
data
[
'summary'
])
# save extra info
extra_info
=
attr
.
map
.
data
[
'extra_info'
]
extra_info
.
type
=
module_desc_pb2
.
MAP
for
key
,
value
in
self
.
extra_info
.
items
():
utils
.
from_pyobj_to_module_attr
(
value
,
extra_info
.
map
.
data
[
key
])
def
__call__
(
self
,
sign_name
,
data
,
use_gpu
=
False
,
batch_size
=
1
,
**
kwargs
):
def
__call__
(
self
,
sign_name
,
data
,
use_gpu
=
False
,
batch_size
=
1
,
**
kwargs
):
self
.
check_processor
()
self
.
check_processor
()
...
@@ -525,6 +515,10 @@ class Module(object):
...
@@ -525,6 +515,10 @@ class Module(object):
if
not
self
.
processor
:
if
not
self
.
processor
:
raise
ValueError
(
"This Module is not callable!"
)
raise
ValueError
(
"This Module is not callable!"
)
@
property
def
is_runable
(
self
):
return
self
.
default_signature
!=
None
def
context
(
self
,
def
context
(
self
,
sign_name
=
None
,
sign_name
=
None
,
for_test
=
False
,
for_test
=
False
,
...
@@ -664,93 +658,3 @@ class Module(object):
...
@@ -664,93 +658,3 @@ class Module(object):
raise
ValueError
(
raise
ValueError
(
"All input and outputs variables in signature should come from the same Program"
"All input and outputs variables in signature should come from the same Program"
)
)
def
serialize_to_path
(
self
,
path
=
None
,
exe
=
None
):
self
.
_check_signatures
()
self
.
_generate_desc
()
# create module path for saving
if
path
is
None
:
path
=
os
.
path
.
join
(
"."
,
self
.
name
)
self
.
helper
=
ModuleHelper
(
path
)
utils
.
mkdir
(
self
.
helper
.
module_dir
)
# create module pb
module_desc
=
module_desc_pb2
.
ModuleDesc
()
logger
.
info
(
"PaddleHub version = %s"
%
version
.
hub_version
)
logger
.
info
(
"PaddleHub Module proto version = %s"
%
version
.
module_proto_version
)
logger
.
info
(
"Paddle version = %s"
%
paddle
.
__version__
)
feeded_var_names
=
[
input
.
name
for
key
,
sign
in
self
.
signatures
.
items
()
for
input
in
sign
.
inputs
]
target_vars
=
[
output
for
key
,
sign
in
self
.
signatures
.
items
()
for
output
in
sign
.
outputs
]
feeded_var_names
=
list
(
set
(
feeded_var_names
))
target_vars
=
list
(
set
(
target_vars
))
# save inference program
program
=
self
.
program
.
clone
()
for
block
in
program
.
blocks
:
for
op
in
block
.
ops
:
if
"op_callstack"
in
op
.
all_attrs
():
op
.
_set_attr
(
"op_callstack"
,
[
""
])
if
not
exe
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
=
place
)
utils
.
mkdir
(
self
.
helper
.
model_path
())
fluid
.
io
.
save_inference_model
(
self
.
helper
.
model_path
(),
feeded_var_names
=
list
(
feeded_var_names
),
target_vars
=
list
(
target_vars
),
main_program
=
program
,
executor
=
exe
)
with
open
(
os
.
path
.
join
(
self
.
helper
.
model_path
(),
"__model__"
),
"rb"
)
as
file
:
program_desc_str
=
file
.
read
()
rename_program
=
fluid
.
framework
.
Program
.
parse_from_string
(
program_desc_str
)
varlist
=
{
var
:
block
for
block
in
rename_program
.
blocks
for
var
in
block
.
vars
if
self
.
get_name_prefix
()
not
in
var
}
for
var
,
block
in
varlist
.
items
():
old_name
=
var
new_name
=
self
.
get_var_name_with_prefix
(
old_name
)
block
.
_rename_var
(
old_name
,
new_name
)
utils
.
mkdir
(
self
.
helper
.
model_path
())
with
open
(
os
.
path
.
join
(
self
.
helper
.
model_path
(),
"__model__"
),
"wb"
)
as
f
:
f
.
write
(
rename_program
.
desc
.
serialize_to_string
())
for
file
in
os
.
listdir
(
self
.
helper
.
model_path
()):
if
(
file
==
"__model__"
or
self
.
get_name_prefix
()
in
file
):
continue
os
.
rename
(
os
.
path
.
join
(
self
.
helper
.
model_path
(),
file
),
os
.
path
.
join
(
self
.
helper
.
model_path
(),
self
.
get_var_name_with_prefix
(
file
)))
# create processor file
if
self
.
processor
:
self
.
_dump_processor
()
# create assets
self
.
_dump_assets
()
# create check info
checker
=
ModuleChecker
(
self
.
helper
.
module_dir
)
checker
.
generate_check_info
()
# Serialize module_desc pb
module_pb
=
self
.
desc
.
SerializeToString
()
with
open
(
self
.
helper
.
module_desc_path
(),
"wb"
)
as
f
:
f
.
write
(
module_pb
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录