Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
6f8fe813
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6f8fe813
编写于
1月 07, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update module, support dict input
上级
0bf02fec
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
80 addition
and
40 deletion
+80
-40
Senta/create_module.sh
Senta/create_module.sh
+1
-1
paddle_hub/module.py
paddle_hub/module.py
+61
-32
tests/test_module.py
tests/test_module.py
+18
-7
未找到文件。
Senta/create_module.sh
浏览文件 @
6f8fe813
python test_create_
hub
.py
--train_data_path
./data/train_data/corpus.train
--word_dict_path
./data/train.vocab
--mode
train
--model_path
./models
python test_create_
module
.py
--train_data_path
./data/train_data/corpus.train
--word_dict_path
./data/train.vocab
--mode
train
--model_path
./models
paddle_hub/module.py
浏览文件 @
6f8fe813
...
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -82,25 +84,31 @@ class Module(object):
# self.dict.setdefault(0)
# self._load_assets(module_dir)
#TODO(ZeyuChen): Need add register more signature to execute different
# implmentation
def
__call__
(
self
,
inputs
=
None
,
signature
=
None
):
""" Call default signature and return results
def
_construct_feed_dict
(
self
,
inputs
):
""" Construct feed dict according to user's inputs and module config.
"""
# TODO(ZeyuChen): add proto spec to check which task we need to run
# if it's NLP word embedding task, then do words preprocessing
# if it's image classification or image feature task do the other works
feed_dict
=
{}
for
k
in
inputs
:
if
k
in
self
.
feed_target_names
:
feed_dict
[
k
]
=
inputs
[
k
]
# if it's
word_ids_lod_tensor
=
self
.
_process_input
(
inputs
)
np_words_id
=
np
.
array
(
word_ids_lod_tensor
)
print
(
"word_ids_lod_tensor
\n
"
,
np_words_id
)
return
feed_dict
def
__call__
(
self
,
inputs
=
None
,
sign_name
=
"default"
):
""" Call default signature and return results
"""
# word_ids_lod_tensor = self._preprocess_input(inputs)
feed_dict
=
self
.
_construct_feed_dict
(
inputs
)
print
(
"feed_dict"
,
feed_dict
)
ret_numpy
=
self
.
config
.
return_numpy
()
print
(
"ret_numpy"
,
ret_numpy
)
results
=
self
.
exe
.
run
(
self
.
inference_program
,
feed
=
{
self
.
feed_target_names
[
0
]:
word_ids_lod_tensor
},
#feed={self.feed_target_names[0]: word_ids_lod_tensor},
feed
=
feed_dict
,
fetch_list
=
self
.
fetch_targets
,
return_numpy
=
False
)
# return_numpy=Flase is important
return_numpy
=
ret_numpy
)
print
(
"module fetch_target_names"
,
self
.
feed_target_names
)
print
(
"module fetch_targets"
,
self
.
fetch_targets
)
...
...
@@ -109,9 +117,15 @@ class Module(object):
return
np_result
def
get_vars
(
self
):
"""
Return variable list of the module program
"""
return
self
.
inference_program
.
list_vars
()
def
get_feed_var
(
self
,
key
,
signature
=
"default"
):
"""
Get feed variable according to variable key and signature
"""
for
var
in
self
.
inference_program
.
list_vars
():
if
var
.
name
==
self
.
config
.
feed_var_name
(
key
,
signature
):
return
var
...
...
@@ -119,6 +133,9 @@ class Module(object):
raise
Exception
(
"Can't find input var {}"
.
format
(
key
))
def
get_fetch_var
(
self
,
key
,
signature
=
"default"
):
"""
Get fetch variable according to variable key and signature
"""
for
var
in
self
.
inference_program
.
list_vars
():
if
var
.
name
==
self
.
config
.
fetch_var_name
(
key
,
signature
):
return
var
...
...
@@ -129,7 +146,7 @@ class Module(object):
return
self
.
inference_program
# for text sequence input, transform to lod tensor as paddle graph's input
def
_process_input
(
self
,
inputs
):
def
_pr
epr
ocess_input
(
self
,
inputs
):
# words id mapping and dealing with oov
# transform to lod tensor
seq
=
[]
...
...
@@ -167,17 +184,22 @@ class ModuleConfig(object):
self
.
desc
=
module_desc_pb2
.
ModuleDesc
()
if
module_name
==
None
:
module_name
=
module_dir
.
split
(
"/"
)[
-
1
]
# initialize module config default value
self
.
desc
.
name
=
module_name
print
(
"desc.name="
,
self
.
desc
.
name
)
self
.
desc
.
contain_assets
=
True
print
(
"desc.signature="
,
self
.
desc
.
contain_assets
)
self
.
desc
.
return_numpy
=
False
# init dict
self
.
dict
=
defaultdict
(
int
)
self
.
dict
.
setdefault
(
0
)
def
get_dict
(
self
):
""" Return dictionary in Module"""
return
self
.
dict
def
load
(
self
):
"""load module config from module dir
"""
Load module config from module directory.
"""
#TODO(ZeyuChen): check module_desc.pb exsitance
pb_path
=
os
.
path
.
join
(
self
.
module_dir
,
"module_desc.pb"
)
...
...
@@ -198,8 +220,7 @@ class ModuleConfig(object):
self
.
dict
[
w
]
=
int
(
w_id
)
def
dump
(
self
):
"""
save module_desc.proto first
""" Save Module configure file to disk.
"""
pb_path
=
os
.
path
.
join
(
self
.
module_dir
,
"module_desc.pb"
)
with
open
(
pb_path
,
"wb"
)
as
fo
:
...
...
@@ -213,6 +234,11 @@ class ModuleConfig(object):
w_id
=
self
.
dict
[
w
]
fo
.
write
(
"{}
\t
{}
\n
"
.
format
(
w
,
w_id
))
def
return_numpy
(
self
):
"""Return numpy or not according to the proto config.
"""
return
self
.
desc
.
return_numpy
def
save_dict
(
self
,
word_dict
,
dict_name
=
DICT_NAME
):
""" Save dictionary for NLP module
"""
...
...
@@ -223,10 +249,13 @@ class ModuleConfig(object):
# for w in word_dict:
# self.dict[w] = word_dict[w]
def
get_dict
(
self
):
return
self
.
dict
def
register_feed_signature
(
self
,
feed_desc
,
sign_name
=
"default"
):
""" Register feed signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for
k
in
feed_desc
:
feed
=
self
.
desc
.
sign2var
[
sign_name
].
feed_desc
.
add
()
...
...
@@ -234,6 +263,12 @@ class ModuleConfig(object):
feed
.
var_name
=
feed_desc
[
k
]
def
register_fetch_signature
(
self
,
fetch_desc
,
sign_name
=
"default"
):
""" Register fetch signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for
k
in
fetch_desc
:
fetch
=
self
.
desc
.
sign2var
[
sign_name
].
fetch_desc
.
add
()
...
...
@@ -241,12 +276,16 @@ class ModuleConfig(object):
fetch
.
var_name
=
fetch_desc
[
k
]
def
feed_var_name
(
self
,
key
,
sign_name
=
"default"
):
"""get module's feed/input variable name
"""
for
desc
in
self
.
desc
.
sign2var
[
sign_name
].
feed_desc
:
if
desc
.
key
==
key
:
return
desc
.
var_name
raise
Exception
(
"feed variable {} not found"
.
format
(
key
))
def
fetch_var_name
(
self
,
key
,
sign_name
=
"default"
):
"""get module's fetch/output variable name
"""
for
desc
in
self
.
desc
.
sign2var
[
sign_name
].
fetch_desc
:
if
desc
.
key
==
key
:
return
desc
.
var_name
...
...
@@ -278,13 +317,3 @@ class ModuleUtils(object):
# print("********************************")
# print(program)
# print("********************************")
if
__name__
==
"__main__"
:
url
=
"http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
m
=
Module
(
module_url
=
url
)
inputs
=
[[
"it"
,
"is"
,
"new"
],
[
"hello"
,
"world"
]]
#tensor = m._process_input(inputs)
#print(tensor)
result
=
m
(
inputs
)
print
(
result
)
tests/test_module.py
浏览文件 @
6f8fe813
...
...
@@ -17,14 +17,25 @@ import paddle_hub as hub
class
TestModule
(
unittest
.
TestCase
):
#TODO(ZeyuChen): add setup for test envrinoment prepration
def
test_word2vec_module_usage
(
self
):
url
=
"http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
module
=
Module
(
module_url
=
url
)
inputs
=
[[
"it"
,
"is"
,
"new"
],
[
"hello"
,
"world"
]]
tensor
=
module
.
_process_input
(
inputs
)
print
(
tensor
)
result
=
module
(
inputs
)
print
(
result
)
pass
# url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
# module = Module(module_url=url)
# inputs = [["it", "is", "new"], ["hello", "world"]]
# tensor = module._process_input(inputs)
# print(tensor)
# result = module(inputs)
# print(result)
def
test_senta_module_usage
(
self
):
pass
# m = Module(module_dir="./models/bow_net")
# inputs = [["外人", "爸妈", "翻车"], ["金钱", "电量"]]
# tensor = m._preprocess_input(inputs)
# print(tensor)
# result = m({"words": tensor})
# print(result)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录