Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
5d27fa77
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5d27fa77
编写于
1月 16, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add variable alias
上级
6decfdb9
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
126 addition
and
19 deletion
+126
-19
paddle_hub/module.py
paddle_hub/module.py
+38
-6
paddle_hub/module_desc.proto
paddle_hub/module_desc.proto
+2
-0
paddle_hub/module_desc_pb2.py
paddle_hub/module_desc_pb2.py
+42
-10
paddle_hub/signature.py
paddle_hub/signature.py
+44
-3
未找到文件。
paddle_hub/module.py
浏览文件 @
5d27fa77
...
...
@@ -107,6 +107,26 @@ class Module(object):
if
op
.
has_attr
(
"is_test"
):
op
.
_set_attr
(
"is_test"
,
is_test
)
def
_process_input_output_key
(
module_desc
,
signature
):
signature
=
module_desc
.
sign2var
[
signature
]
feed_dict
=
{}
fetch_dict
=
{}
for
index
,
feed
in
enumerate
(
signature
.
feed_desc
):
if
feed
.
alias
!=
""
:
feed_dict
[
feed
.
alias
]
=
feed
.
var_name
feed_dict
[
index
]
=
feed
.
var_name
for
index
,
fetch
in
enumerate
(
signature
.
fetch_desc
):
if
fetch
.
alias
!=
""
:
fetch_dict
[
fetch
.
alias
]
=
fetch
.
var_name
fetch_dict
[
index
]
=
fetch
.
var_name
return
feed_dict
,
fetch_dict
self
.
config
=
ModuleConfig
(
self
.
module_dir
)
self
.
config
.
load
()
# load paddle inference model
place
=
fluid
.
CPUPlace
()
model_dir
=
os
.
path
.
join
(
self
.
module_dir
,
MODEL_DIRNAME
)
...
...
@@ -114,15 +134,15 @@ class Module(object):
self
.
inference_program
,
self
.
feed_target_names
,
self
.
fetch_targets
=
fluid
.
io
.
load_inference_model
(
dirname
=
os
.
path
.
join
(
model_dir
,
sign_name
),
executor
=
self
.
exe
)
feed_dict
,
fetch_dict
=
_process_input_output_key
(
self
.
config
.
desc
,
sign_name
)
# remove feed fetch operator and variable
ModuleUtils
.
remove_feed_fetch_op
(
self
.
inference_program
)
# print("inference_program")
# print(self.inference_program)
print
(
"**feed_target_names**
\n
{}"
.
format
(
self
.
feed_target_names
))
print
(
"**fetch_targets**
\n
{}"
.
format
(
self
.
fetch_targets
))
self
.
config
=
ModuleConfig
(
self
.
module_dir
)
self
.
config
.
load
()
self
.
_process_parameter
()
name_generator_path
=
ModuleConfig
.
name_generator_path
(
self
.
module_dir
)
with
open
(
name_generator_path
,
"rb"
)
as
data
:
...
...
@@ -133,7 +153,15 @@ class Module(object):
_process_op_attr
(
program
=
program
,
is_test
=
False
)
_set_param_trainable
(
program
=
program
,
trainable
=
trainable
)
return
self
.
feed_target_names
,
self
.
fetch_targets
,
program
,
generator
for
key
,
value
in
feed_dict
.
items
():
var
=
program
.
global_block
().
var
(
value
)
feed_dict
[
key
]
=
var
for
key
,
value
in
fetch_dict
.
items
():
var
=
program
.
global_block
().
var
(
value
)
fetch_dict
[
key
]
=
var
return
feed_dict
,
fetch_dict
,
program
,
generator
def
get_inference_program
(
self
):
return
self
.
inference_program
...
...
@@ -315,13 +343,17 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
var
=
sign_map
[
sign
.
get_name
()]
feed_desc
=
var
.
feed_desc
fetch_desc
=
var
.
fetch_desc
for
input
in
sign
.
get_inputs
():
feed_names
=
sign
.
get_feed_names
()
fetch_names
=
sign
.
get_fetch_names
()
for
index
,
input
in
enumerate
(
sign
.
get_inputs
()):
feed_var
=
feed_desc
.
add
()
feed_var
.
var_name
=
input
.
name
feed_var
.
alias
=
feed_names
[
index
]
for
output
in
sign
.
get_outputs
(
):
for
index
,
output
in
enumerate
(
sign
.
get_outputs
()
):
fetch_var
=
fetch_desc
.
add
()
fetch_var
.
var_name
=
output
.
name
fetch_var
.
alias
=
fetch_names
[
index
]
# save inference program
exe
=
fluid
.
Executor
(
place
=
fluid
.
CPUPlace
())
...
...
paddle_hub/module_desc.proto
浏览文件 @
5d27fa77
...
...
@@ -21,11 +21,13 @@ package paddle_hub;
// Feed Variable Description
message
FeedDesc
{
string
var_name
=
1
;
string
alias
=
2
;
};
// Fetch Variable Description
message
FetchDesc
{
string
var_name
=
1
;
string
alias
=
2
;
};
// Module Variable
...
...
paddle_hub/module_desc_pb2.py
浏览文件 @
5d27fa77
...
...
@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package
=
'paddle_hub'
,
syntax
=
'proto3'
,
serialized_pb
=
_b
(
'
\n\x11
module_desc.proto
\x12\n
paddle_hub
\"
\x1c\n\x08\x46\x65\x65\x64\x44\x65
sc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\"\x1d\n\t
FetchDesc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\"
_
\n\t
ModuleVar
\x12
)
\n\n
fetch_desc
\x18\x01
\x03
(
\x0b\x32\x15
.paddle_hub.FetchDesc
\x12\'\n\t
feed_desc
\x18\x02
\x03
(
\x0b\x32\x14
.paddle_hub.FeedDesc
\"\xd9\x01\n\n
ModuleDesc
\x12\x0c\n\x04
name
\x18\x01
\x01
(
\t\x12\x36\n\x08
sign2var
\x18\x02
\x03
(
\x0b\x32
$.paddle_hub.ModuleDesc.Sign2varEntry
\x12\x14\n\x0c
return_numpy
\x18\x03
\x01
(
\x08\x12\x16\n\x0e\x63
ontain_assets
\x18\x04
\x01
(
\x08\x12\x0f\n\x07
version
\x18\x05
\x01
(
\t\x1a\x46\n\r
Sign2varEntry
\x12\x0b\n\x03
key
\x18\x01
\x01
(
\t\x12
$
\n\x05
value
\x18\x02
\x01
(
\x0b\x32\x15
.paddle_hub.ModuleVar:
\x02\x38\x01\x42\x02
H
\x03\x62\x06
proto3'
'
\n\x11
module_desc.proto
\x12\n
paddle_hub
\"
+
\n\x08\x46\x65\x65\x64\x44\x65
sc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\x12\r\n\x05\x61
lias
\x18\x02
\x01
(
\t\"
,
\n\t
FetchDesc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\x12\r\n\x05\x61
lias
\x18\x02
\x01
(
\t\"
_
\n\t
ModuleVar
\x12
)
\n\n
fetch_desc
\x18\x01
\x03
(
\x0b\x32\x15
.paddle_hub.FetchDesc
\x12\'\n\t
feed_desc
\x18\x02
\x03
(
\x0b\x32\x14
.paddle_hub.FeedDesc
\"\xd9\x01\n\n
ModuleDesc
\x12\x0c\n\x04
name
\x18\x01
\x01
(
\t\x12\x36\n\x08
sign2var
\x18\x02
\x03
(
\x0b\x32
$.paddle_hub.ModuleDesc.Sign2varEntry
\x12\x14\n\x0c
return_numpy
\x18\x03
\x01
(
\x08\x12\x16\n\x0e\x63
ontain_assets
\x18\x04
\x01
(
\x08\x12\x0f\n\x07
version
\x18\x05
\x01
(
\t\x1a\x46\n\r
Sign2varEntry
\x12\x0b\n\x03
key
\x18\x01
\x01
(
\t\x12
$
\n\x05
value
\x18\x02
\x01
(
\x0b\x32\x15
.paddle_hub.ModuleVar:
\x02\x38\x01\x42\x02
H
\x03\x62\x06
proto3'
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
...
...
@@ -44,6 +44,22 @@ _FEEDDESC = _descriptor.Descriptor(
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'alias'
,
full_name
=
'paddle_hub.FeedDesc.alias'
,
index
=
1
,
number
=
2
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
...
...
@@ -54,7 +70,7 @@ _FEEDDESC = _descriptor.Descriptor(
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
33
,
serialized_end
=
61
,
serialized_end
=
76
,
)
_FETCHDESC
=
_descriptor
.
Descriptor
(
...
...
@@ -80,6 +96,22 @@ _FETCHDESC = _descriptor.Descriptor(
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'alias'
,
full_name
=
'paddle_hub.FetchDesc.alias'
,
index
=
1
,
number
=
2
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
...
...
@@ -89,8 +121,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
63
,
serialized_end
=
9
2
,
serialized_start
=
78
,
serialized_end
=
12
2
,
)
_MODULEVAR
=
_descriptor
.
Descriptor
(
...
...
@@ -141,8 +173,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
9
4
,
serialized_end
=
18
9
,
serialized_start
=
12
4
,
serialized_end
=
21
9
,
)
_MODULEDESC_SIGN2VARENTRY
=
_descriptor
.
Descriptor
(
...
...
@@ -194,8 +226,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
3
3
9
,
serialized_end
=
4
0
9
,
serialized_start
=
3
6
9
,
serialized_end
=
4
3
9
,
)
_MODULEDESC
=
_descriptor
.
Descriptor
(
...
...
@@ -296,8 +328,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
19
2
,
serialized_end
=
4
0
9
,
serialized_start
=
22
2
,
serialized_end
=
4
3
9
,
)
_MODULEVAR
.
fields_by_name
[
'fetch_desc'
].
message_type
=
_FETCHDESC
...
...
paddle_hub/signature.py
浏览文件 @
5d27fa77
...
...
@@ -20,11 +20,24 @@ from paddle_hub.utils import to_list
class
Signature
:
def
__init__
(
self
,
name
,
inputs
,
outputs
):
self
.
name
=
name
def
__init__
(
self
,
name
,
inputs
,
outputs
,
feed_names
=
None
,
fetch_names
=
None
):
inputs
=
to_list
(
inputs
)
outputs
=
to_list
(
outputs
)
if
not
feed_names
:
feed_names
=
[
""
]
*
len
(
inputs
)
feed_names
=
to_list
(
feed_names
)
assert
len
(
inputs
)
==
len
(
feed_names
),
"the length of feed_names must be same with inputs"
if
not
fetch_names
:
fetch_names
=
[
""
]
*
len
(
outputs
)
fetch_names
=
to_list
(
fetch_names
)
assert
len
(
outputs
)
==
len
(
fetch_names
),
"the length of fetch_names must be same with outputs"
self
.
name
=
name
for
item
in
inputs
:
assert
isinstance
(
item
,
...
...
@@ -37,6 +50,29 @@ class Signature:
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
feed_names
=
feed_names
self
.
fetch_names
=
fetch_names
# self.inputs_dict = {}
# for index, value in enumerate(inputs):
# self.inputs_dict[index] = value
# if feed_names:
# for index in range(len(feed_names)):
# key = feed_names[index]
# value = inputs[index]
# self.inputs_dict[key] = value
# self.outputs_dict = {}
# for index, value in enumerate(outputs):
# self.outputs_dict[index] = value
# if feed_names:
# for index in range(len(fetch_names)):
# key = fetch_names[index]
# value = outputs[index]
# self.outputs_dict[key] = value
def
get_name
(
self
):
return
self
.
name
...
...
@@ -47,7 +83,12 @@ class Signature:
def
get_outputs
(
self
):
return
self
.
outputs
def
get_feed_names
(
self
):
return
self
.
feed_names
def
get_fetch_names
(
self
):
return
self
.
fetch_names
def
create_signature
(
name
=
"default"
,
inputs
=
[],
outputs
=
[]):
def
create_signature
(
name
=
"default"
,
inputs
=
[],
outputs
=
[]):
return
Signature
(
name
=
name
,
inputs
=
inputs
,
outputs
=
outputs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录