Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a8a0a8f1
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a8a0a8f1
编写于
1月 15, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add paddlehub version string, fix test case
上级
12e3dd4d
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
104 addition
and
162 deletion
+104
-162
example/sentiment-classification/sentiment_classify.py
example/sentiment-classification/sentiment_classify.py
+62
-64
paddle_hub/__init__.py
paddle_hub/__init__.py
+1
-0
paddle_hub/downloader.py
paddle_hub/downloader.py
+0
-4
paddle_hub/module.py
paddle_hub/module.py
+9
-11
paddle_hub/module_desc.proto
paddle_hub/module_desc.proto
+1
-4
paddle_hub/module_desc_pb2.py
paddle_hub/module_desc_pb2.py
+14
-62
tests/test_export_n_load_module.py
tests/test_export_n_load_module.py
+17
-17
未找到文件。
example/sentiment-classification/sentiment_classify.py
浏览文件 @
a8a0a8f1
...
...
@@ -199,8 +199,10 @@ def finetune_net(train_reader,
module_dir
=
os
.
path
.
join
(
save_dirname
,
network_name
)
module
=
hub
.
Module
(
module_dir
=
module_dir
)
feed_list
,
fetch_list
,
program
=
module
(
sign_name
=
"default"
,
trainable
=
True
)
feed_list
,
fetch_list
,
program
,
generator
=
module
(
sign_name
=
"default"
,
trainable
=
True
)
with
fluid
.
program_guard
(
main_program
=
program
):
with
fluid
.
unique_name
.
guard
(
generator
):
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
1
],
dtype
=
"int64"
)
# data = module.get_feed_var_by_index(0)
#TODO(ZeyuChen): how to get output paramter according to proto config
...
...
@@ -256,14 +258,10 @@ def finetune_net(train_reader,
print
(
"[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f"
%
(
pass_id
,
avg_acc
,
avg_cost
))
# # save the model
# module_dir = os.path.join(save_dirname, network_name)
# signature = hub.create_signature(
# "default", inputs=[data], outputs=[sent_emb])
# hub.create_module(
# sign_arr=signature,
# program=fluid.default_main_program(),
# path=module_dir)
# save the model
model_dir
=
os
.
path
.
join
(
save_dirname
,
network_name
+
"_finetune"
)
fluid
.
io
.
save_persistables
(
executor
=
exe
,
dirname
=
model_dir
,
main_program
=
None
)
def
eval_net
(
test_reader
,
use_gpu
,
model_path
=
None
):
...
...
paddle_hub/__init__.py
浏览文件 @
a8a0a8f1
...
...
@@ -24,3 +24,4 @@ from paddle_hub.module import ModuleUtils
from
paddle_hub.module
import
create_module
from
paddle_hub.downloader
import
download_and_uncompress
from
paddle_hub.signature
import
create_signature
from
paddle_hub.version
import
__version__
paddle_hub/downloader.py
浏览文件 @
a8a0a8f1
...
...
@@ -56,7 +56,6 @@ def md5file(fname):
def
download_and_uncompress
(
url
,
save_name
=
None
):
module_name
=
url
.
split
(
"/"
)[
-
2
]
dirname
=
os
.
path
.
join
(
MODULE_HOME
,
module_name
)
print
(
"download to dir"
,
dirname
)
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
...
...
@@ -115,6 +114,3 @@ if __name__ == "__main__":
module_path
=
download_and_uncompress
(
link
)
print
(
"module path"
,
module_path
)
# dl = DownloadManager()
# dl.download_and_uncompress(link, "./tmp")
paddle_hub/module.py
浏览文件 @
a8a0a8f1
...
...
@@ -29,6 +29,7 @@ from paddle_hub.downloader import download_and_uncompress
from
paddle_hub
import
module_desc_pb2
from
paddle_hub.signature
import
Signature
from
paddle_hub.utils
import
to_list
from
paddle_hub.version
import
__version__
__all__
=
[
"Module"
,
"ModuleConfig"
,
"ModuleUtils"
]
...
...
@@ -259,13 +260,14 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
# create module path for saving
mkdir
(
module_dir
)
module
=
module_desc_pb2
.
ModuleDesc
()
module_desc
=
module_desc_pb2
.
ModuleDesc
()
module_desc
.
version
=
__version__
program
=
program
.
clone
()
if
word_dict
is
None
:
module
.
contain_assets
=
False
module
_desc
.
contain_assets
=
False
else
:
module
.
contain_assets
=
True
module
_desc
.
contain_assets
=
True
with
open
(
ModuleConfig
.
assets_dict_path
(
module_dir
),
"w"
)
as
fo
:
for
w
in
word_dict
:
w_id
=
word_dict
[
w
]
...
...
@@ -301,7 +303,7 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
pickle
.
dump
(
param_arr
,
fo
)
# save signarture info
sign_map
=
module
.
sign2var
sign_map
=
module
_desc
.
sign2var
sign_arr
=
to_list
(
sign_arr
)
for
sign
in
sign_arr
:
assert
isinstance
(
sign
,
...
...
@@ -335,10 +337,10 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
main_program
=
program
,
executor
=
exe
)
#
save to disk
data
=
module
.
SerializeToString
()
#
Serialize module_desc pb
module_pb
=
module_desc
.
SerializeToString
()
with
open
(
ModuleConfig
.
module_desc_path
(
module_dir
),
"wb"
)
as
f
:
f
.
write
(
data
)
f
.
write
(
module_pb
)
class
ModuleUtils
(
object
):
...
...
@@ -363,7 +365,3 @@ class ModuleUtils(object):
block
.
_remove_var
(
"fetch"
)
program
.
desc
.
flush
()
@
staticmethod
def
module_desc_path
(
module_dir
):
pass
paddle_hub/module_desc.proto
浏览文件 @
a8a0a8f1
...
...
@@ -18,9 +18,6 @@ option optimize_for = LITE_RUNTIME;
package
paddle_hub
;
message
Version
{
int64
version
=
1
;
}
// Feed Variable Description
message
FeedDesc
{
string
var_name
=
1
;
...
...
@@ -51,6 +48,6 @@ message ModuleDesc {
bool
contain_assets
=
4
;
Version
version
=
5
;
string
version
=
5
;
};
paddle_hub/module_desc_pb2.py
浏览文件 @
a8a0a8f1
...
...
@@ -17,46 +17,10 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package
=
'paddle_hub'
,
syntax
=
'proto3'
,
serialized_pb
=
_b
(
'
\n\x11
module_desc.proto
\x12\n
paddle_hub
\"\x1
a\n\x07
Version
\x12\x0f\n\x07
version
\x18\x01
\x01
(
\x03\"\x1c\n\x08\x46\x65\x65\x64\x44\x65
sc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\"\x1d\n\t
FetchDesc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\"
_
\n\t
ModuleVar
\x12
)
\n\n
fetch_desc
\x18\x01
\x03
(
\x0b\x32\x15
.paddle_hub.FetchDesc
\x12\'\n\t
feed_desc
\x18\x02
\x03
(
\x0b\x32\x14
.paddle_hub.FeedDesc
\"\xee\x01\n\n
ModuleDesc
\x12\x0c\n\x04
name
\x18\x01
\x01
(
\t\x12\x36\n\x08
sign2var
\x18\x02
\x03
(
\x0b\x32
$.paddle_hub.ModuleDesc.Sign2varEntry
\x12\x14\n\x0c
return_numpy
\x18\x03
\x01
(
\x08\x12\x16\n\x0e\x63
ontain_assets
\x18\x04
\x01
(
\x08\x12
$
\n\x07
version
\x18\x05
\x01
(
\x0b\x32\x13
.paddle_hub.Version
\x1a\x46\n\r
Sign2varEntry
\x12\x0b\n\x03
key
\x18\x01
\x01
(
\t\x12
$
\n\x05
value
\x18\x02
\x01
(
\x0b\x32\x15
.paddle_hub.ModuleVar:
\x02\x38\x01\x42\x02
H
\x03\x62\x06
proto3'
'
\n\x11
module_desc.proto
\x12\n
paddle_hub
\"\x1
c\n\x08\x46\x65\x65\x64\x44\x65
sc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\"\x1d\n\t
FetchDesc
\x12\x10\n\x08
var_name
\x18\x01
\x01
(
\t\"
_
\n\t
ModuleVar
\x12
)
\n\n
fetch_desc
\x18\x01
\x03
(
\x0b\x32\x15
.paddle_hub.FetchDesc
\x12\'\n\t
feed_desc
\x18\x02
\x03
(
\x0b\x32\x14
.paddle_hub.FeedDesc
\"\xd9\x01\n\n
ModuleDesc
\x12\x0c\n\x04
name
\x18\x01
\x01
(
\t\x12\x36\n\x08
sign2var
\x18\x02
\x03
(
\x0b\x32
$.paddle_hub.ModuleDesc.Sign2varEntry
\x12\x14\n\x0c
return_numpy
\x18\x03
\x01
(
\x08\x12\x16\n\x0e\x63
ontain_assets
\x18\x04
\x01
(
\x08\x12\x0f\n\x07
version
\x18\x05
\x01
(
\t
\x1a\x46\n\r
Sign2varEntry
\x12\x0b\n\x03
key
\x18\x01
\x01
(
\t\x12
$
\n\x05
value
\x18\x02
\x01
(
\x0b\x32\x15
.paddle_hub.ModuleVar:
\x02\x38\x01\x42\x02
H
\x03\x62\x06
proto3'
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
_VERSION
=
_descriptor
.
Descriptor
(
name
=
'Version'
,
full_name
=
'paddle_hub.Version'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'version'
,
full_name
=
'paddle_hub.Version.version'
,
index
=
0
,
number
=
1
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
33
,
serialized_end
=
59
,
)
_FEEDDESC
=
_descriptor
.
Descriptor
(
name
=
'FeedDesc'
,
full_name
=
'paddle_hub.FeedDesc'
,
...
...
@@ -89,8 +53,8 @@ _FEEDDESC = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
61
,
serialized_end
=
89
,
serialized_start
=
33
,
serialized_end
=
61
,
)
_FETCHDESC
=
_descriptor
.
Descriptor
(
...
...
@@ -125,8 +89,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
91
,
serialized_end
=
120
,
serialized_start
=
63
,
serialized_end
=
92
,
)
_MODULEVAR
=
_descriptor
.
Descriptor
(
...
...
@@ -177,8 +141,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
122
,
serialized_end
=
217
,
serialized_start
=
94
,
serialized_end
=
189
,
)
_MODULEDESC_SIGN2VARENTRY
=
_descriptor
.
Descriptor
(
...
...
@@ -230,8 +194,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
3
88
,
serialized_end
=
4
58
,
serialized_start
=
3
39
,
serialized_end
=
4
09
,
)
_MODULEDESC
=
_descriptor
.
Descriptor
(
...
...
@@ -310,11 +274,11 @@ _MODULEDESC = _descriptor.Descriptor(
full_name
=
'paddle_hub.ModuleDesc.version'
,
index
=
4
,
number
=
5
,
type
=
11
,
cpp_type
=
10
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
)
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
...
...
@@ -332,8 +296,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
220
,
serialized_end
=
4
58
,
serialized_start
=
192
,
serialized_end
=
4
09
,
)
_MODULEVAR
.
fields_by_name
[
'fetch_desc'
].
message_type
=
_FETCHDESC
...
...
@@ -341,23 +305,11 @@ _MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC
_MODULEDESC_SIGN2VARENTRY
.
fields_by_name
[
'value'
].
message_type
=
_MODULEVAR
_MODULEDESC_SIGN2VARENTRY
.
containing_type
=
_MODULEDESC
_MODULEDESC
.
fields_by_name
[
'sign2var'
].
message_type
=
_MODULEDESC_SIGN2VARENTRY
_MODULEDESC
.
fields_by_name
[
'version'
].
message_type
=
_VERSION
DESCRIPTOR
.
message_types_by_name
[
'Version'
]
=
_VERSION
DESCRIPTOR
.
message_types_by_name
[
'FeedDesc'
]
=
_FEEDDESC
DESCRIPTOR
.
message_types_by_name
[
'FetchDesc'
]
=
_FETCHDESC
DESCRIPTOR
.
message_types_by_name
[
'ModuleVar'
]
=
_MODULEVAR
DESCRIPTOR
.
message_types_by_name
[
'ModuleDesc'
]
=
_MODULEDESC
Version
=
_reflection
.
GeneratedProtocolMessageType
(
'Version'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VERSION
,
__module__
=
'module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.Version)
))
_sym_db
.
RegisterMessage
(
Version
)
FeedDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'FeedDesc'
,
(
_message
.
Message
,
),
...
...
tests/test_export_n_load_module.py
浏览文件 @
a8a0a8f1
...
...
@@ -165,11 +165,11 @@ def test_create_w2v_module(use_gpu=False):
def
test_load_w2v_module
(
use_gpu
=
False
):
saved_module_dir
=
"./tmp/word2vec_test_module"
w2v_module
=
hub
.
Module
(
module_dir
=
saved_module_dir
)
feed_list
,
fetch_list
,
program
=
w2v_module
(
feed_list
,
fetch_list
,
program
,
generator
=
w2v_module
(
sign_name
=
"default"
,
trainable
=
False
)
with
fluid
.
program_guard
(
main_program
=
program
):
with
fluid
.
unique_name
.
guard
(
generator
):
pred_prob
=
fetch_list
[
0
]
pred_word
=
fluid
.
layers
.
argmax
(
x
=
pred_prob
,
axis
=
1
)
# set place, executor, datafeeder
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录