Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
d8ed5400
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d8ed5400
编写于
4月 01, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add default signature
上级
33a0fc45
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
85 addition
and
63 deletion
+85
-63
demo/image-classification/create_module.py
demo/image-classification/create_module.py
+1
-1
demo/lac/create_module.py
demo/lac/create_module.py
+4
-1
demo/senta/create_module.py
demo/senta/create_module.py
+4
-1
demo/senta/processor.py
demo/senta/processor.py
+8
-9
demo/ssd/create_module.py
demo/ssd/create_module.py
+5
-1
paddle_hub/commands/run.py
paddle_hub/commands/run.py
+38
-37
paddle_hub/module/module.py
paddle_hub/module/module.py
+13
-9
paddle_hub/module/signature.py
paddle_hub/module/signature.py
+12
-4
未找到文件。
demo/image-classification/create_module.py
浏览文件 @
d8ed5400
...
...
@@ -48,7 +48,7 @@ def create_module(args):
# create paddle hub module
assets
=
[
"resources/label_list.txt"
]
sign1
=
hub
.
create_signature
(
"classification"
,
inputs
=
[
image
],
outputs
=
[
predition
])
"classification"
,
inputs
=
[
image
],
outputs
=
[
predition
]
,
for_predict
=
True
)
sign2
=
hub
.
create_signature
(
"feature_map"
,
inputs
=
[
image
],
outputs
=
[
feature_map
])
hub
.
create_module
(
...
...
demo/lac/create_module.py
浏览文件 @
d8ed5400
...
...
@@ -37,7 +37,10 @@ def create_module():
# create a module and save as hub_module_lac
sign
=
hub
.
create_signature
(
name
=
"lexical_analysis"
,
inputs
=
[
word
],
outputs
=
[
crf_decode
])
name
=
"lexical_analysis"
,
inputs
=
[
word
],
outputs
=
[
crf_decode
],
for_predict
=
True
)
hub
.
create_module
(
sign_arr
=
[
sign
],
module_dir
=
"hub_module_lac"
,
...
...
demo/senta/create_module.py
浏览文件 @
d8ed5400
...
...
@@ -42,7 +42,10 @@ def create_module():
# create a module
sign
=
hub
.
create_signature
(
name
=
"sentiment_classify"
,
inputs
=
[
data
],
outputs
=
[
pred
])
name
=
"sentiment_classify"
,
inputs
=
[
data
],
outputs
=
[
pred
],
for_predict
=
True
)
hub
.
create_module
(
sign_arr
=
[
sign
],
module_dir
=
"hub_module_senta"
,
...
...
demo/senta/processor.py
浏览文件 @
d8ed5400
import
os
import
io
import
paddle
import
paddle.fluid
as
fluid
import
paddle_hub
as
hub
import
numpy
as
np
import
os
import
io
from
paddle_hub
import
BaseProcessor
from
paddle_hub.hub_server
import
default_hub_server
from
paddle_hub.module.manager
import
default_module_manager
import
paddle_hub
as
hub
def
load_vocab
(
file_path
):
...
...
@@ -37,17 +36,17 @@ def get_predict_label(pos_prob):
return
label
,
key
class
Processor
(
BaseProcessor
):
class
Processor
(
hub
.
BaseProcessor
):
def
__init__
(
self
,
module
):
self
.
module
=
module
assets_path
=
self
.
module
.
helper
.
assets_path
()
word_dict_path
=
os
.
path
.
join
(
assets_path
,
"train.vocab"
)
self
.
word_dict
=
load_vocab
(
word_dict_path
)
path
=
default_module_manager
.
search_module
(
"lac"
)
path
=
hub
.
default_module_manager
.
search_module
(
"lac"
)
if
path
:
self
.
lac
=
hub
.
Module
(
module_dir
=
path
)
else
:
result
,
_
,
path
=
default_module_manager
.
install_module
(
"lac"
)
result
,
_
,
path
=
hub
.
default_module_manager
.
install_module
(
"lac"
)
assert
path
,
"can't found necessary module lac"
self
.
lac
=
hub
.
Module
(
module_dir
=
path
)
...
...
demo/ssd/create_module.py
浏览文件 @
d8ed5400
...
...
@@ -41,10 +41,14 @@ def create_module():
assets
=
[
"resources/label_list.txt"
]
sign
=
hub
.
create_signature
(
"object_detection"
,
inputs
=
[
image
],
outputs
=
[
nmsed_out
])
"object_detection"
,
inputs
=
[
image
],
outputs
=
[
nmsed_out
],
for_predict
=
True
)
hub
.
create_module
(
sign_arr
=
[
sign
],
module_dir
=
"hub_module_ssd"
,
module_info
=
"resources/module_info.yml"
,
exe
=
exe
,
processor
=
processor
.
Processor
,
assets
=
assets
)
...
...
paddle_hub/commands/run.py
浏览文件 @
d8ed5400
...
...
@@ -42,29 +42,10 @@ class RunCommand(BaseCommand):
# yapf: disable
self
.
add_arg
(
'--config'
,
str
,
None
,
"config file in yaml format"
)
self
.
add_arg
(
'--dataset'
,
str
,
None
,
"dataset be used"
)
self
.
add_arg
(
'--data'
,
str
,
None
,
"data be used"
)
self
.
add_arg
(
'--signature'
,
str
,
None
,
"signature to run"
)
# yapf: enable
def
_check_dataset
(
self
):
if
not
self
.
args
.
dataset
:
print
(
"Error! Lack of dataset file"
)
self
.
help
()
exit
(
1
)
if
not
utils
.
is_csv_file
(
self
.
args
.
dataset
):
print
(
"Error! Dataset file should in csv format"
)
self
.
help
()
exit
(
1
)
def
_check_config
(
self
):
if
not
self
.
args
.
config
:
print
(
"Error! Lack of config file"
)
self
.
help
()
exit
(
1
)
if
not
utils
.
is_yaml_file
(
self
.
args
.
config
):
print
(
"Error! Config file should in yaml format"
)
self
.
help
()
exit
(
1
)
def
exec
(
self
,
argv
):
if
not
argv
:
print
(
"ERROR: Please specify a key
\n
"
)
...
...
@@ -72,8 +53,6 @@ class RunCommand(BaseCommand):
return
False
module_name
=
argv
[
0
]
self
.
args
=
self
.
parser
.
parse_args
(
argv
[
1
:])
self
.
_check_dataset
()
self
.
_check_config
()
module_dir
=
default_module_manager
.
search_module
(
module_name
)
if
not
module_dir
:
...
...
@@ -88,30 +67,52 @@ class RunCommand(BaseCommand):
return
False
module
=
hub
.
Module
(
module_dir
=
module_dir
)
yaml_config
=
yaml_reader
.
read
(
self
.
args
.
config
)
if
not
module
.
default_signature
:
print
(
"ERROR! Module %s is not callable"
%
module_name
)
if
not
self
.
args
.
signature
:
self
.
args
.
signature
=
module
.
default_signature
().
name
self
.
args
.
signature
=
module
.
default_signature
.
name
# module processor check
module
.
check_processor
()
# data_format check
expect_data_format
=
module
.
processor
.
data_format
(
self
.
args
.
signature
)
input_data_format
=
yaml_config
[
'input_data'
]
assert
len
(
input_data_format
)
==
len
(
expect_data_format
)
for
key
,
value
in
expect_data_format
.
items
():
assert
key
in
input_data_format
assert
value
[
'type'
]
==
hub
.
DataType
.
type
(
input_data_format
[
key
][
'type'
])
# get data dict
origin_data
=
csv_reader
.
read
(
self
.
args
.
dataset
)
input_data
=
{}
for
key
,
value
in
yaml_config
[
'input_data'
].
items
():
input_data
[
key
]
=
origin_data
[
value
[
'key'
]]
if
self
.
args
.
data
:
input_data_key
=
list
(
expect_data_format
.
keys
())[
0
]
origin_data
=
{
input_data_key
:
[
self
.
args
.
data
]}
elif
self
.
args
.
dataset
:
origin_data
=
csv_reader
.
read
(
self
.
args
.
dataset
)
else
:
print
(
"ERROR! Please specify data to predict"
)
self
.
help
()
exit
(
1
)
# data_format check
if
not
self
.
args
.
config
:
assert
len
(
expect_data_format
)
==
1
origin_data_key
=
list
(
origin_data
.
keys
())[
0
]
input_data_key
=
list
(
expect_data_format
.
keys
())[
0
]
input_data
=
{
input_data_key
:
origin_data
[
origin_data_key
]}
config
=
{}
else
:
yaml_config
=
yaml_reader
.
read
(
self
.
args
.
config
)
if
len
(
expect_data_format
)
==
1
:
origin_data_key
=
list
(
origin_data
.
keys
())[
0
]
input_data_key
=
list
(
expect_data_format
.
keys
())[
0
]
input_data
=
{
input_data_key
:
origin_data
[
origin_data_key
]}
else
:
input_data_format
=
yaml_config
[
'input_data'
]
assert
len
(
input_data_format
)
==
len
(
expect_data_format
)
for
key
,
value
in
expect_data_format
.
items
():
assert
key
in
input_data_format
assert
value
[
'type'
]
==
hub
.
DataType
.
type
(
input_data_format
[
key
][
'type'
])
input_data
=
{}
for
key
,
value
in
yaml_config
[
'input_data'
].
items
():
input_data
[
key
]
=
origin_data
[
value
[
'key'
]]
config
=
yaml_config
.
get
(
"config"
,
{})
# run module with data
config
=
yaml_config
.
get
(
"config"
,
{})
print
(
module
(
sign_name
=
self
.
args
.
signature
,
data
=
input_data
,
**
config
))
...
...
paddle_hub/module/module.py
浏览文件 @
d8ed5400
...
...
@@ -224,6 +224,8 @@ class Module(object):
for
sign
in
signatures
:
if
sign
.
name
in
self
.
signatures
:
raise
"Error! signature array contains repeat signatrue %s"
%
sign
if
self
.
default_signature
is
None
and
sign
.
for_predict
:
self
.
default_signature
=
sign
self
.
signatures
[
sign
.
name
]
=
sign
def
_recovery_parameter
(
self
,
program
):
...
...
@@ -308,6 +310,12 @@ class Module(object):
feed_names
=
feed_names
,
fetch_names
=
fetch_names
)
# recover default signature
default_signature_name
=
utils
.
from_flexible_data_to_pyobj
(
self
.
desc
.
extra_info
.
map
.
data
[
'default_signature'
])
self
.
default_signature
=
self
.
signatures
[
default_signature_name
]
if
default_signature_name
else
None
# recover module info
module_info
=
self
.
desc
.
extra_info
.
map
.
data
[
'module_info'
]
self
.
name
=
utils
.
from_flexible_data_to_pyobj
(
...
...
@@ -362,6 +370,11 @@ class Module(object):
fetch_var
.
var_name
=
self
.
get_var_name_with_prefix
(
output
.
name
)
fetch_var
.
alias
=
fetch_names
[
index
]
# save default signature
utils
.
from_pyobj_to_flexible_data
(
self
.
default_signature
.
name
if
self
.
default_signature
else
None
,
extra_info
.
map
.
data
[
'default_signature'
])
# save module info
module_info
=
extra_info
.
map
.
data
[
'module_info'
]
module_info
.
type
=
module_desc_pb2
.
MAP
...
...
@@ -512,15 +525,6 @@ class Module(object):
def
get_var_name_with_prefix
(
self
,
var_name
):
return
self
.
get_name_prefix
()
+
var_name
def
parameters
(
self
):
pass
def
parameter_attrs
(
self
):
pass
def
default_signature
(
self
):
return
self
.
default_signature
def
_check_signatures
(
self
):
assert
self
.
signatures
,
"signature array should not be None"
...
...
paddle_hub/module/signature.py
浏览文件 @
d8ed5400
...
...
@@ -20,8 +20,13 @@ from paddle_hub.common.utils import to_list
class
Signature
:
def
__init__
(
self
,
name
,
inputs
,
outputs
,
feed_names
=
None
,
fetch_names
=
None
):
def
__init__
(
self
,
name
,
inputs
,
outputs
,
feed_names
=
None
,
fetch_names
=
None
,
for_predict
=
False
):
inputs
=
to_list
(
inputs
)
outputs
=
to_list
(
outputs
)
...
...
@@ -52,16 +57,19 @@ class Signature:
self
.
outputs
=
outputs
self
.
feed_names
=
feed_names
self
.
fetch_names
=
fetch_names
self
.
for_predict
=
for_predict
def
create_signature
(
name
=
"default"
,
inputs
=
[],
outputs
=
[],
feed_names
=
None
,
fetch_names
=
None
):
fetch_names
=
None
,
for_predict
=
False
):
return
Signature
(
name
=
name
,
inputs
=
inputs
,
outputs
=
outputs
,
feed_names
=
feed_names
,
fetch_names
=
fetch_names
)
fetch_names
=
fetch_names
,
for_predict
=
for_predict
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录