Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
3ab7da0c
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
281
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3ab7da0c
编写于
1月 02, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix typo
上级
7ec25147
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
128 addition
and
82 deletion
+128
-82
Senta/sentiment_classify.py
Senta/sentiment_classify.py
+7
-29
paddle_hub/downloader.py
paddle_hub/downloader.py
+1
-1
paddle_hub/module.py
paddle_hub/module.py
+89
-32
paddle_hub/module_desc.proto
paddle_hub/module_desc.proto
+18
-7
paddle_hub/setup.cfg
paddle_hub/setup.cfg
+0
-2
requirements.txt
requirements.txt
+1
-0
setup.py
setup.py
+3
-4
test_export_n_load_module.py
test_export_n_load_module.py
+9
-7
未找到文件。
Senta/sentiment_classify.py
浏览文件 @
3ab7da0c
...
...
@@ -20,7 +20,6 @@ from nets import cnn_net
from
nets
import
lstm_net
from
nets
import
bilstm_net
from
nets
import
gru_net
logger
=
logging
.
getLogger
(
"paddle-fluid"
)
logger
.
setLevel
(
logging
.
INFO
)
...
...
@@ -93,28 +92,6 @@ def parse_args():
return
args
def
remove_feed_fetch_op
(
program
):
""" remove feed and fetch operator and variable for fine-tuning
"""
print
(
"remove feed fetch op"
)
block
=
program
.
global_block
()
need_to_remove_op_index
=
[]
for
i
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"feed"
or
op
.
type
==
"fetch"
:
need_to_remove_op_index
.
append
(
i
)
for
index
in
need_to_remove_op_index
[::
-
1
]:
block
.
_remove_op
(
index
)
block
.
_remove_var
(
"feed"
)
block
.
_remove_var
(
"fetch"
)
program
.
desc
.
flush
()
print
(
"********************************"
)
print
(
program
)
print
(
"********************************"
)
def
train_net
(
train_reader
,
word_dict
,
network_name
,
...
...
@@ -224,6 +201,7 @@ def retrain_net(train_reader,
fluid
.
framework
.
switch_main_program
(
module
.
get_inference_program
())
# remove feed fetch operator and variable
ModuleUtils
.
remove_feed_fetch_op
(
fluid
.
default_main_program
())
remove_feed_fetch_op
(
fluid
.
default_main_program
())
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
1
],
dtype
=
"int64"
)
...
...
@@ -231,6 +209,9 @@ def retrain_net(train_reader,
#TODO(ZeyuChen): how to get output paramter according to proto config
emb
=
module
.
get_module_output
()
print
(
"adfjkajdlfjoqi jqiorejlmsfdlkjoi jqwierjoajsdklfjoi qjerijoajdfiqwjeor adfkalsf"
)
# # # embedding layer
# emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
...
...
@@ -376,12 +357,9 @@ def main(args):
args
.
word_dict_path
,
args
.
batch_size
,
args
.
mode
)
# train_net(train_reader, word_dict, args.model_type, args.use_gpu,
# args.is_parallel, args.model_path, args.lr, args.batch_size,
# args.num_passes)
retrain_net
(
train_reader
,
word_dict
,
args
.
model_type
,
args
.
use_gpu
,
args
.
is_parallel
,
args
.
model_path
,
args
.
lr
,
args
.
batch_size
,
args
.
num_passes
)
train_net
(
train_reader
,
word_dict
,
args
.
model_type
,
args
.
use_gpu
,
args
.
is_parallel
,
args
.
model_path
,
args
.
lr
,
args
.
batch_size
,
args
.
num_passes
)
# eval mode
elif
args
.
mode
==
"eval"
:
...
...
paddle_hub/downloader.py
浏览文件 @
3ab7da0c
...
...
@@ -109,7 +109,7 @@ def download_and_uncompress(url, save_name=None):
for
file_name
in
file_names
:
tar
.
extract
(
file_name
,
dirname
)
return
module_dir
return
module_
name
,
module_
dir
class
TqdmProgress
(
tqdm
):
...
...
paddle_hub/module.py
浏览文件 @
3ab7da0c
...
...
@@ -19,15 +19,15 @@ from __future__ import print_function
import
paddle.fluid
as
fluid
import
numpy
as
np
import
tempfile
import
utils
import
os
import
module_desc_pb2
from
collections
import
defaultdict
from
downloader
import
download_and_uncompress
__all__
=
[
"Module"
,
"Module
Desc
"
]
__all__
=
[
"Module"
,
"Module
Config"
,
"ModuleUtils
"
]
DICT_NAME
=
"dict.txt"
ASSETS_
PATH
=
"assets"
ASSETS_
NAME
=
"assets"
def
mkdir
(
path
):
...
...
@@ -40,12 +40,13 @@ def mkdir(path):
class
Module
(
object
):
def
__init__
(
self
,
module_url
):
# donwload module
if
module_url
.
startswith
(
"http"
):
# if it's remote url links
if
module_url
.
startswith
(
"http"
):
# if it's remote url link, then download and uncompress it
module_dir
=
download_and_uncompress
(
module_url
)
module_
name
,
module_
dir
=
download_and_uncompress
(
module_url
)
else
:
# otherwise it's local path, no need to deal with it
module_dir
=
module_url
module_name
=
module_url
.
split
()[
-
1
]
# load paddle inference model
place
=
fluid
.
CPUPlace
()
...
...
@@ -62,9 +63,9 @@ class Module(object):
print
(
self
.
fetch_targets
)
# load assets
self
.
dict
=
defaultdict
(
int
)
self
.
dict
.
setdefault
(
0
)
self
.
_load_assets
(
module_dir
)
#
self.dict = defaultdict(int)
#
self.dict.setdefault(0)
#
self._load_assets(module_dir)
#TODO(ZeyuChen): Need add register more signature to execute different
# implmentation
...
...
@@ -92,6 +93,9 @@ class Module(object):
return
np_result
def
add_input_desc
(
var_name
):
pass
def
get_vars
(
self
):
return
self
.
inference_program
.
list_vars
()
...
...
@@ -144,23 +148,17 @@ class Module(object):
# load assets folder
def
_load_assets
(
self
,
module_dir
):
assets_dir
=
os
.
path
.
join
(
module_dir
,
ASSETS_
PATH
)
tokens
_path
=
os
.
path
.
join
(
assets_dir
,
DICT_NAME
)
assets_dir
=
os
.
path
.
join
(
module_dir
,
ASSETS_
NAME
)
dict
_path
=
os
.
path
.
join
(
assets_dir
,
DICT_NAME
)
word_id
=
0
with
open
(
tokens
_path
)
as
fi
:
with
open
(
dict
_path
)
as
fi
:
words
=
fi
.
readlines
()
#TODO(ZeyuChen) check whether word id is duplicated and valid
for
line
in
fi
:
w
,
w_id
=
line
.
split
()
self
.
dict
[
w
]
=
int
(
w_id
)
# words = map(str.strip, words)
# for w in words:
# self.dict[w] = word_id
# word_id += 1
# print(w, word_id)
def
add_module_feed_list
(
self
,
feed_list
):
self
.
feed_list
=
feed_list
...
...
@@ -168,30 +166,89 @@ class Module(object):
self
.
output_list
=
output_list
class
ModuleDesc
(
object
):
def
__init__
(
self
):
pass
@
staticmethod
def
save_dict
(
path
,
word_dict
,
dict_name
):
""" Save dictionary for NLP module
class
ModuleConfig
(
object
):
def
__init__
(
self
,
module_dir
):
# generate model desc protobuf
self
.
module_dir
=
module_dir
self
.
desc
=
module_desc_pb3
.
ModuleDesc
()
self
.
desc
.
name
=
module_name
print
(
"desc.name="
,
self
.
desc
.
name
)
self
.
desc
.
signature
=
"default"
print
(
"desc.signature="
,
self
.
desc
.
signature
)
self
.
desc
.
contain_assets
=
True
print
(
"desc.signature="
,
self
.
desc
.
contain_assets
)
def
load
(
module_dir
):
"""load module config from module dir
"""
mkdir
(
path
)
with
open
(
os
.
path
.
join
(
path
,
dict_name
),
"w"
)
as
fo
:
print
(
"tokens.txt path"
,
os
.
path
.
join
(
path
,
DICT_NAME
))
#TODO(ZeyuChen): check module_desc.pb exsitance
with
open
(
pb_file_path
,
"rb"
)
as
fi
:
self
.
desc
.
ParseFromString
(
fi
.
read
())
if
self
.
desc
.
contain_assets
:
# load assets
self
.
dict
=
defaultdict
(
int
)
self
.
dict
.
setdefault
(
0
)
assets_dir
=
os
.
path
.
join
(
self
.
module_dir
,
assets_dir
)
dict_path
=
os
.
path
.
join
(
assets_dir
,
DICT_NAME
)
word_id
=
0
with
open
(
dict_path
)
as
fi
:
words
=
fi
.
readlines
()
#TODO(ZeyuChen) check whether word id is duplicated and valid
for
line
in
fi
:
w
,
w_id
=
line
.
split
()
self
.
dict
[
w
]
=
int
(
w_id
)
def
dump
():
# save module_desc.proto first
pb_path
=
os
.
path
.
join
(
self
.
module
,
"module_desc.pb"
)
with
open
(
pb_path
,
"wb"
)
as
fo
:
fo
.
write
(
self
.
desc
.
SerializeToString
())
# save assets/dictionary
assets_dir
=
os
.
path
.
join
(
self
.
module_dir
,
assets_dir
)
mkdir
(
assets_dir
)
with
open
(
os
.
path
.
join
(
assets_dir
,
DICT_NAME
),
"w"
)
as
fo
:
for
w
in
word_dict
:
w_id
=
word_dict
[
w
]
fo
.
write
(
"{}
\t
{}
\n
"
.
format
(
w
,
w_id
))
@
staticmethod
def
save_module_dict
(
module_path
,
word_dict
,
dict_name
=
DICT_NAME
):
def
save_dict
(
word_dict
,
dict_name
=
DICT_NAME
):
""" Save dictionary for NLP module
"""
assets_path
=
os
.
path
.
join
(
module_path
,
ASSETS_PATH
)
print
(
"save_module_dict"
,
assets_path
)
ModuleDesc
.
save_dict
(
assets_path
,
word_dict
,
dict_name
)
mkdir
(
path
)
with
open
(
os
.
path
.
join
(
self
.
module_dir
,
DICT_NAME
),
"w"
)
as
fo
:
for
w
in
word_dict
:
self
.
dict
[
w
]
=
word_dict
[
w
]
class
ModuleUtils
(
object
):
def
__init__
(
self
):
pass
@
staticmethod
def
remove_feed_fetch_op
(
program
):
""" remove feed and fetch operator and variable for fine-tuning
"""
print
(
"remove feed fetch op"
)
block
=
program
.
global_block
()
need_to_remove_op_index
=
[]
for
i
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"feed"
or
op
.
type
==
"fetch"
:
need_to_remove_op_index
.
append
(
i
)
for
index
in
need_to_remove_op_index
[::
-
1
]:
block
.
_remove_op
(
index
)
block
.
_remove_var
(
"feed"
)
block
.
_remove_var
(
"fetch"
)
program
.
desc
.
flush
()
print
(
"********************************"
)
print
(
program
)
print
(
"********************************"
)
if
__name__
==
"__main__"
:
module_link
=
"http://paddlehub.cdn.bcebos.com/word2vec/w2v_saved_inference_module.tar.gz"
...
...
paddle_hub/module_desc.proto
浏览文件 @
3ab7da0c
...
...
@@ -12,23 +12,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax
=
"proto3"
;
option
optimize_for
=
LITE_RUNTIME
;
package
paddle_hub
;
message
InputDesc
{
}
string
name
=
1
;
};
message
OutputDesc
{
bool
return_numpy
=
1
;
}
// A Hub Module is stored in a directory with a file 'paddlehub_module.pb'
string
name
=
1
;
};
// A Hub Module is stored in a directory with a file 'paddlehub.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message
ModuleDesc
{
string
name
=
1
;
// PaddleHub module name
repeated
InputDesc
input_desc
=
2
;
repeated
OutputDesc
output_desc
=
3
;
string
signature
=
4
;
bool
return_numpy
=
5
;
repeated
string
input_signature
}
bool
contain_assets
=
6
;
}
;
paddle_hub/setup.cfg
已删除
100644 → 0
浏览文件 @
7ec25147
[metadata]
license_file = LICENSE
requirements.txt
0 → 100644
浏览文件 @
3ab7da0c
paddlepaddle
paddle_hub/
setup.py
→
setup.py
浏览文件 @
3ab7da0c
#
Copyright 2018 The TensorFlow Hub
Authors. All Rights Reserved.
#
Copyright (c) 2019 PaddlePaddle
Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
);
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
...
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Setup for pip package."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -29,7 +28,7 @@ REQUIRED_PACKAGES = [
]
setup
(
name
=
'paddle_hub'
,
name
=
'paddle_hub'
,
version
=
__version__
.
replace
(
'-'
,
''
),
description
=
(
'PaddleHub is a library to foster the publication, '
'discovery, and consumption of reusable parts of machine '
...
...
test_export_n_load_module.py
浏览文件 @
3ab7da0c
...
...
@@ -184,7 +184,7 @@ def train(use_cuda=False):
dictionary
.
append
(
w
)
# save word dict to assets folder
hub
.
Module
Desc
.
save_module_dict
(
hub
.
Module
Config
.
save_module_dict
(
module_path
=
saved_model_path
,
word_dict
=
dictionary
)
...
...
@@ -214,9 +214,9 @@ def test_save_module(use_cuda=False):
np_result
=
np
.
array
(
results
[
0
])
print
(
np_result
)
saved_module_
path
=
"./test/word2vec_inference_module"
saved_module_
dir
=
"./test/word2vec_inference_module"
fluid
.
io
.
save_inference_model
(
dirname
=
saved_module_
path
,
dirname
=
saved_module_
dir
,
feeded_var_names
=
[
"words"
],
target_vars
=
[
word_emb
],
executor
=
exe
)
...
...
@@ -227,17 +227,19 @@ def test_save_module(use_cuda=False):
w
=
w
.
decode
(
"ascii"
)
dictionary
.
append
(
w
)
# save word dict to assets folder
hub
.
ModuleDesc
.
save_module_dict
(
module_path
=
saved_module_path
,
word_dict
=
dictionary
)
config
=
hub
.
ModuleConfig
(
saved_module_dir
)
config
.
save_dict
(
word_dict
=
dictionary
)
config
.
dump
()
def
test_load_module
(
use_cuda
=
False
):
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
saved_module_
path
=
"./test/word2vec_inference_module"
saved_module_
dir
=
"./test/word2vec_inference_module"
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
fluid
.
io
.
load_inference_model
(
saved_module_
path
,
executor
=
exe
)
saved_module_
dir
,
executor
=
exe
)
# Sequence input in Paddle must be LOD Tensor, so we need to convert them inside Module
word_ids
=
[[
1
,
2
,
3
,
4
,
5
]]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录