PaddlePaddle / PaddleHub
Commit 715fde80

Merge module_creator.py to module.py

Authored Jan 15, 2019 by Zeyu Chen
Parent: b9f9ff25
Showing 8 changed files with 317 additions and 312 deletions (+317 −312)
example/sentiment-classification/sentiment_classify.py   +81  −82
example/sentiment-classification/train.sh                 +1   −1
paddle_hub/__init__.py                                   +15   −2
paddle_hub/config.py                                      +0  −35
paddle_hub/module.py                                    +140  −72
paddle_hub/module_creator.py                              +0 −109
paddle_hub/module_desc.proto                              +5   −0
paddle_hub/module_desc_pb2.py                            +75  −11
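
The net effect of the merge is that `create_module` now lives in paddle_hub/module.py and takes `module_dir`/`word_dict` instead of `path`/`assets`. A minimal before/after sketch of a call site (the tiny program, dictionary, and directory below are hypothetical, not part of the diff):

    import paddle.fluid as fluid
    import paddle_hub as hub

    # hypothetical one-layer program whose embedding we want to export
    words = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    emb = fluid.layers.embedding(input=words, size=[100, 128])
    signature = hub.create_signature("default", inputs=[words], outputs=[emb])

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())  # parameters must exist before export

    # before this commit (paddle_hub.module_creator):
    #   hub.create_module(sign_arr=signature,
    #                     program=fluid.default_main_program(),
    #                     path="./bow_module")
    # after this commit (paddle_hub.module):
    hub.create_module(
        sign_arr=signature,
        program=fluid.default_main_program(),
        module_dir="./bow_module",
        word_dict={"apple": 0, "banana": 1})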
example/sentiment-classification/sentiment_classify.py
@@ -160,18 +160,19 @@ def train_net(train_reader,
     hub.create_module(
-        sign_arr=signature, program=fluid.default_main_program(), path=module_dir)
+        sign_arr=signature, program=fluid.default_main_program(),
+        module_dir=module_dir, word_dict=word_dict)


-def retrain_net(train_reader,
-                word_dict,
-                network_name,
-                use_gpu,
-                parallel,
-                save_dirname,
-                lr=0.002,
-                batch_size=128,
-                pass_num=30):
+def finetune_net(train_reader,
+                 word_dict,
+                 network_name,
+                 use_gpu,
+                 parallel,
+                 save_dirname,
+                 lr=0.002,
+                 batch_size=128,
+                 pass_num=30):
     """
     train network
     """
@@ -198,73 +199,71 @@ def retrain_net(train_reader,
     module_dir = os.path.join(save_dirname, network_name)
     module = hub.Module(module_dir=module_dir)
-    main_program = fluid.Program()
-    startup_program = fluid.Program()
-    # use switch program to test fine-tuning
-    fluid.framework.switch_main_program(module.get_inference_program())
-    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
-    data = module.get_feed_var_by_index(0)
-    #TODO(ZeyuChen): how to get output paramter according to proto config
-    sent_emb = module.get_fetch_var_by_index(0)
-    fc_1 = fluid.layers.fc(
-        input=sent_emb, size=hid_dim, act="tanh", name="bow_fc1")
-    fc_2 = fluid.layers.fc(
-        input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
-    # softmax layer
-    pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
-    # print(fluid.default_main_program())
-    cost = fluid.layers.mean(
-        fluid.layers.cross_entropy(input=pred, label=label))
-    acc = fluid.layers.accuracy(input=pred, label=label)
-    with open("./prototxt/bow_net.forward.program_desc.prototxt", "w") as fo:
-        program_desc = str(fluid.default_main_program())
-        fo.write(program_desc)
-    # set optimizer
-    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
-    sgd_optimizer.minimize(cost)
-    with open("./prototxt/bow_net.finetune.program_desc.prototxt", "w") as fo:
-        program_desc = str(fluid.default_main_program())
-        fo.write(program_desc)
-    # set place, executor, datafeeder
-    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-    feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
-    exe.run(fluid.default_startup_program())
-    # start training...
-    for pass_id in range(pass_num):
-        data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
-        for batch in train_reader():
-            avg_cost_np, avg_acc_np = exe.run(
-                fluid.default_main_program(),
-                feed=feeder.feed(batch),
-                fetch_list=[cost, acc],
-                return_numpy=True)
-            data_size = len(batch)
-            total_acc += data_size * avg_acc_np
-            total_cost += data_size * avg_cost_np
-            data_count += data_size
-        avg_cost = total_cost / data_count
-        avg_acc = total_acc / data_count
-        print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
-              (pass_id, avg_acc, avg_cost))
-    # save the model
-    module_dir = os.path.join(save_dirname, network_name)
-    signature = hub.create_signature(
-        "default", inputs=[data], outputs=[sent_emb])
-    hub.create_module(
-        sign_arr=signature, program=fluid.default_main_program(), path=module_dir)
+    feed_list, fetch_list, program = module(
+        sign_name="default", trainable=True)
+    with fluid.program_guard(main_program=program):
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+        # data = module.get_feed_var_by_index(0)
+        #TODO(ZeyuChen): how to get output paramter according to proto config
+        sent_emb = fetch_list[0]
+        # sent_emb = module.get_fetch_var_by_index(0)
+        fc_1 = fluid.layers.fc(
+            input=sent_emb, size=hid_dim, act="tanh", name="bow_fc1")
+        fc_2 = fluid.layers.fc(
+            input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
+        # softmax layer
+        pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
+        # print(fluid.default_main_program())
+        cost = fluid.layers.mean(
+            fluid.layers.cross_entropy(input=pred, label=label))
+        acc = fluid.layers.accuracy(input=pred, label=label)
+        with open("./prototxt/bow_net.forward.program_desc.prototxt",
+                  "w") as fo:
+            program_desc = str(fluid.default_main_program())
+            fo.write(program_desc)
+        # set optimizer
+        sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
+        sgd_optimizer.minimize(cost)
+        with open("./prototxt/bow_net.finetune.program_desc.prototxt",
+                  "w") as fo:
+            program_desc = str(fluid.default_main_program())
+            fo.write(program_desc)
+        # set place, executor, datafeeder
+        place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
+        exe.run(fluid.default_startup_program())
+        # start training...
+        for pass_id in range(pass_num):
+            data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
+            for batch in train_reader():
+                avg_cost_np, avg_acc_np = exe.run(
+                    fluid.default_main_program(),
+                    feed=feeder.feed(batch),
+                    fetch_list=[cost, acc],
+                    return_numpy=True)
+                data_size = len(batch)
+                total_acc += data_size * avg_acc_np
+                total_cost += data_size * avg_cost_np
+                data_count += data_size
+            avg_cost = total_cost / data_count
+            avg_acc = total_acc / data_count
+            print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
+                  (pass_id, avg_acc, avg_cost))
+        # # save the model
+        # module_dir = os.path.join(save_dirname, network_name)
+        # signature = hub.create_signature(
+        #     "default", inputs=[data], outputs=[sent_emb])
+        # hub.create_module(
+        #     sign_arr=signature,
+        #     program=fluid.default_main_program(),
+        #     path=module_dir)


 def eval_net(test_reader, use_gpu, model_path=None):
@@ -367,9 +366,9 @@ def main(args):
             args.word_dict_path,
             args.batch_size,
             args.mode)
-        retrain_net(train_reader, word_dict, args.model_type, args.use_gpu,
-                    args.is_parallel, args.model_path, args.lr,
-                    args.batch_size, args.num_passes)
+        finetune_net(train_reader, word_dict, args.model_type, args.use_gpu,
+                     args.is_parallel, args.model_path, args.lr,
+                     args.batch_size, args.num_passes)
     # eval mode
     elif args.mode == "eval":
         # prepare_data to get word_dict, test_reader
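
For context, the fine-tuning flow this file now uses: `hub.Module` loads a saved module and, when called, returns feed names, fetch targets, and a trainable clone of its inference program. A condensed sketch (the module directory is hypothetical):

    import paddle.fluid as fluid
    import paddle_hub as hub

    module = hub.Module(module_dir="./models/bow_net")  # hypothetical path
    feed_list, fetch_list, program = module(sign_name="default", trainable=True)

    with fluid.program_guard(main_program=program):
        # new task-specific head appended inside the module's program
        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
        sent_emb = fetch_list[0]
        pred = fluid.layers.fc(input=sent_emb, size=2, act="softmax")
        cost = fluid.layers.mean(
            fluid.layers.cross_entropy(input=pred, label=label))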
example/sentiment-classification/train.sh

-python sentiment_classify.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models
+python sentiment_classify.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models --num_passes=1
paddle_hub/__init__.py

+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

@@ -7,7 +21,6 @@ import paddle.fluid as fluid
 from paddle_hub.module import Module
 from paddle_hub.module import ModuleConfig
 from paddle_hub.module import ModuleUtils
+from paddle_hub.module import create_module
 from paddle_hub.downloader import download_and_uncompress
 from paddle_hub.signature import create_signature
-from paddle_hub.module_creator import create_module
-from paddle_hub.config import RunConfig, ParamTrainConfig
paddle_hub/config.py
deleted 100644 → 0

-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from enum import Enum, unique
-
-
-@unique
-class ParamTrainConfig(Enum):
-    PARAM_TRAIN_DEFAULT = 0
-    PARAM_TRAIN_ALL = 1
-    PARAM_TRAIN_NONE = 2
-
-
-class RunConfig:
-    def __init__(self, param_train_config=None):
-        assert (not param_train_config or param_train_config in
-                ParamTrainConfig
-                ), "train config should be value of %s" % ParamTrainConfig
-        if not param_train_config:
-            param_train_config = ParamTrainConfig.PARAM_TRAIN_DEFAULT
-        self.param_train_config = param_train_config
paddle_hub/module.py
@@ -27,11 +27,20 @@ import pickle
 from collections import defaultdict
 from paddle_hub.downloader import download_and_uncompress
 from paddle_hub import module_desc_pb2
-from paddle_hub.config import RunConfig, ParamTrainConfig
 from paddle_hub.signature import Signature
 from paddle_hub.utils import to_list

 __all__ = ["Module", "ModuleConfig", "ModuleUtils"]
-DICT_NAME = "dict.txt"
-ASSETS_NAME = "assets"
+# paddle hub module dir name
+ASSETS_DIRNAME = "assets"
+META_DIRNAME = "meta"
+MODEL_DIRNAME = "model"
+# paddle hub module serialze file name
+DICT_FILENAME = "vocab.txt"
+PARAM_FILENAME = "param.pkl"
+MODULE_DESC_PBNAME = "module_desc.pb"
+GENERATOR_FILENAME = "unique_name_generator.pkl"


 def mkdir(path):
@@ -67,8 +76,7 @@ class Module(object):
         # load paddle inference model
         place = fluid.CPUPlace()
-        model_dir = os.path.join(self.module_dir, "model")
-        print("model_dir", model_dir)
+        model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
         self.exe = fluid.Executor(fluid.CPUPlace())
         [self.inference_program, self.feed_target_names,
          self.fetch_targets] = fluid.io.load_inference_model(
@@ -91,14 +99,15 @@ class Module(object):
             self._process_uqn()

     def _process_uqn(self):
-        filepath = os.path.join(self.module_dir, "uqn.pkl")
-        with open(filepath, "rb") as file:
-            fluid.unique_name.switch(pickle.load(file))
+        name_generator_path = ModuleConfig.name_generator_path(self.module_dir)
+        with open(name_generator_path, "rb") as fi:
+            fluid.unique_name.switch(pickle.load(fi))

     def _process_parameter(self):
         global_block = self.inference_program.global_block()
-        filepath = os.path.join(self.module_dir, "param.pkl")
-        with open(filepath, "rb") as file:
+        param_path = ModuleConfig.meta_param_path(self.module_dir)
+        with open(param_path, "rb") as file:
             param_arr = pickle.load(file)
         for param in param_arr:
             if (param['name'] not in global_block.vars):
@@ -124,7 +133,7 @@ class Module(object):
         return feed_dict

-    def __call__(self, sign_name="default", run_config=None):
+    def __call__(self, sign_name="default", trainable=False):
         """ Call default signature and return results
         """

@@ -137,16 +146,10 @@ class Module(object):
             if op.has_attr("is_test"):
                 op._set_attr("is_test", is_test)

-        if not run_config:
-            run_config = RunConfig()

         program = self.get_inference_program().clone()
         _process_op_attr(program=program, is_test=False)
-        if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
-            _set_param_trainable(program=program, trainable=True)
-        elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
-            _set_param_trainable(program=program, trainable=False)
+        _set_param_trainable(program=program, trainable=trainable)

         return self.feed_target_names, self.fetch_targets, program
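
Two things are worth noting about this change: the old code selected trainability through a RunConfig enum, and its elif branch compared against PARAM_TRAIN_ALL twice, so PARAM_TRAIN_NONE could never take effect; the boolean flag removes both the indirection and the bug. A sketch of the call site before and after (illustrative only):

    # before: enum-driven, via the now-deleted paddle_hub.config
    # feed, fetch, program = module(
    #     sign_name="default",
    #     run_config=RunConfig(
    #         param_train_config=ParamTrainConfig.PARAM_TRAIN_ALL))

    # after: a plain boolean forwarded to _set_param_trainable
    feed, fetch, program = module(sign_name="default", trainable=True)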
@@ -282,79 +285,30 @@ class ModuleConfig(object):
         Load module config from module directory.
         """
         #TODO(ZeyuChen): check module_desc.pb exsitance
-        pb_path = os.path.join(self.module_dir, "module_desc.pb")
-        with open(pb_path, "rb") as fi:
+        with open(ModuleConfig.module_desc_path(self.module_dir), "rb") as fi:
             self.desc.ParseFromString(fi.read())

         # print("self.desc.sign2var",
         #       self.desc.sign2var["default"].feed_desc[0].var_name)
         if self.desc.contain_assets:
             # load assets
-            assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
-            dict_path = os.path.join(assets_dir, DICT_NAME)
             word_id = 0
-            with open(dict_path) as fi:
-                words = fi.readlines()
+            with open(ModuleConfig.assets_dict_path(self.module_dir)) as fi:
                 #TODO(ZeyuChen) check whether word id is duplicated and valid
                 for line in fi:
                     w, w_id = line.split()
                     self.dict[w] = int(w_id)

-    def dump(self):
-        """ Save Module configure file to disk.
-        """
-        pb_path = os.path.join(self.module_dir, "module_desc.pb")
-        with open(pb_path, "wb") as fo:
-            fo.write(self.desc.SerializeToString())
-
-        # save assets/dictionary
-        assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
-        mkdir(assets_dir)
-        with open(os.path.join(assets_dir, DICT_NAME), "w") as fo:
-            for w in self.dict:
-                w_id = self.dict[w]
-                fo.write("{}\t{}\n".format(w, w_id))
-
-    def return_numpy(self):
-        """Return numpy or not according to the proto config.
-        """
-        return self.desc.return_numpy
-
-    def save_dict(self, word_dict, dict_name=DICT_NAME):
+    def save_dict(self, word_dict, dict_name=DICT_FILENAME):
         """ Save dictionary for NLP module
         """
         for w in word_dict:
             self.dict[w] = word_dict[w]

-    def register_feed_signature(self, feed_desc, sign_name="default"):
-        """ Register feed signature to the Module
-        Args:
-            fetch_desc: a dictionary of signature to input variable
-            sign_name: signature name, use "default" as default signature
-        """
-        #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
-        for k in feed_desc:
-            feed = self.desc.sign2var[sign_name].feed_desc.add()
-            feed.key = k
-            feed.var_name = feed_desc[k]
-
-    def register_fetch_signature(self, fetch_desc, sign_name="default"):
-        """ Register fetch signature to the Module
-        Args:
-            fetch_desc: a dictionary of signature to input variable
-            sign_name: signature name, use "default" as default signature
-        """
-        #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
-        for k in fetch_desc:
-            fetch = self.desc.sign2var[sign_name].fetch_desc.add()
-            fetch.key = k
-            fetch.var_name = fetch_desc[k]
-
     def feed_var_names(self, sign_name="default"):
         return self.desc.sign2var[sign_name].feed_desc
@@ -377,6 +331,119 @@ class ModuleConfig(object):
                 return desc.var_name
         raise Exception("fetch variable {} not found".format(key))

+    @staticmethod
+    def module_desc_path(module_dir):
+        return os.path.join(module_dir, MODULE_DESC_PBNAME)
+
+    @staticmethod
+    def name_generator_path(module_dir):
+        meta_path = os.path.join(module_dir, META_DIRNAME)
+        mkdir(meta_path)
+        return os.path.join(meta_path, GENERATOR_FILENAME)
+
+    @staticmethod
+    def assets_dict_path(module_dir):
+        assets_path = os.path.join(module_dir, ASSETS_DIRNAME)
+        mkdir(assets_path)
+        return os.path.join(assets_path, DICT_FILENAME)
+
+    @staticmethod
+    def meta_param_path(module_dir):
+        meta_path = os.path.join(module_dir, META_DIRNAME)
+        mkdir(meta_path)
+        return os.path.join(meta_path, PARAM_FILENAME)
+
+    @staticmethod
+    def meta_name_generator_path(module_dir):
+        meta_path = os.path.join(module_dir, META_DIRNAME)
+        mkdir(meta_path)
+        return os.path.join(meta_path, GENERATOR_FILENAME)
+
+
+def create_module(sign_arr, program, module_dir=None, word_dict=None):
+    """ Create a module from main program
+    """
+    assert isinstance(
+        program, fluid.Program), "program should be instance of fluid.Program"
+    assert sign_arr, "signature array should not be None"
+
+    if module_dir is None:
+        module_dir = os.path.join(".", "hub_module")
+    # create module path for saving
+    mkdir(module_dir)
+
+    module = module_desc_pb2.ModuleDesc()
+    program = program.clone()
+
+    if word_dict is None:
+        module.contain_assets = False
+    else:
+        module.contain_assets = True
+        with open(ModuleConfig.assets_dict_path(module_dir), "w") as fo:
+            for w in word_dict:
+                w_id = word_dict[w]
+                fo.write("{}\t{}\n".format(w, w_id))
+
+    # save the unique name generator object
+    generator = fluid.unique_name.generator
+    with open(ModuleConfig.name_generator_path(module_dir), "wb") as fo:
+        pickle.dump(generator, fo)
+
+    # save fluid Parameter
+    param_arr = []
+    for param in program.global_block().iter_parameters():
+        param_info = {
+            'name': param.name,
+            'regularizer': param.regularizer,
+            'gradient_clip_attr': param.gradient_clip_attr,
+            'trainable': param.trainable,
+            'optimize_attr': param.optimize_attr,
+            'do_model_average': param.do_model_average
+        }
+        param_arr.append(param_info)
+    with open(ModuleConfig.meta_param_path(module_dir), "wb") as fo:
+        pickle.dump(param_arr, fo)
+
+    # save signarture info
+    sign_map = module.sign2var
+    sign_arr = to_list(sign_arr)
+    for sign in sign_arr:
+        assert isinstance(
+            sign, Signature), "sign_arr should be list of Signature"
+        if sign.get_name() in sign_map:
+            raise "Error! sign_arr contains repeat signatrue %s" % sign
+
+        var = sign_map[sign.get_name()]
+        feed_desc = var.feed_desc
+        fetch_desc = var.fetch_desc
+        for input in sign.get_inputs():
+            feed_var = feed_desc.add()
+            feed_var.var_name = input.name
+        for output in sign.get_outputs():
+            fetch_var = fetch_desc.add()
+            fetch_var.var_name = output.name
+
+    # save inference program
+    exe = fluid.Executor(place=fluid.CPUPlace())
+    model_dir = os.path.join(module_dir, "model")
+    mkdir(model_dir)
+    # TODO(ZeyuChen): here only deal with one signature
+    first_sign = sign_arr[0]
+    fluid.io.save_inference_model(
+        model_dir,
+        feeded_var_names=[var.name for var in first_sign.get_inputs()],
+        target_vars=first_sign.get_outputs(),
+        main_program=program,
+        executor=exe)
+
+    # save to disk
+    data = module.SerializeToString()
+    with open(ModuleConfig.module_desc_path(module_dir), "wb") as f:
+        f.write(data)


 class ModuleUtils(object):
     def __init__(self):
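
Taken together, the new path helpers fix the module's on-disk layout. A small sketch of what create_module leaves behind and how the helpers resolve paths (directory name hypothetical; note the helpers create meta/ and assets/ as a side effect):

    from paddle_hub.module import ModuleConfig

    # ./hub_module/
    #     module_desc.pb                  (MODULE_DESC_PBNAME)
    #     assets/vocab.txt                (ASSETS_DIRNAME / DICT_FILENAME)
    #     meta/param.pkl                  (META_DIRNAME / PARAM_FILENAME)
    #     meta/unique_name_generator.pkl  (META_DIRNAME / GENERATOR_FILENAME)
    #     model/                          (inference program and parameters)
    print(ModuleConfig.module_desc_path("./hub_module"))  # ./hub_module/module_desc.pb
    print(ModuleConfig.assets_dict_path("./hub_module"))  # ./hub_module/assets/vocab.txt
    print(ModuleConfig.meta_param_path("./hub_module"))   # ./hub_module/meta/param.pkl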
@@ -400,6 +467,7 @@ class ModuleUtils(object):
         block._remove_var("fetch")
         program.desc.flush()
         # print("********************************")
         # print(program)
         # print("********************************")
+
+    @staticmethod
+    def module_desc_path(module_dir):
+        pass
paddle_hub/module_creator.py
deleted 100644 → 0

-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import paddle_hub.module_desc_pb2 as modulepb
-import paddle.fluid as fluid
-from paddle_hub.utils import to_list
-from paddle_hub.signature import Signature
-from paddle_hub.module import mkdir
-import os
-import pickle
-
-
-def create_module(sign_arr, program, path=None, assets=None):
-    assert isinstance(
-        program, fluid.Program), "program should be instance of fluid.Program"
-    assert sign_arr, "signarture array should not be None"
-
-    if not path:
-        path = os.path.join(".", "hub_module")
-    # create module path for saving
-    mkdir(path)
-
-    module = modulepb.ModuleDesc()
-    program = program.clone()
-
-    # TODO(wuzewu): save assets data
-    if not assets:
-        module.contain_assets = False
-    else:
-        module.contain_assets = True
-        os.makedirs(os.path.join(path, "assets"))
-
-    # save the unique name object
-    generator = fluid.unique_name.generator
-    pklname = os.path.join(path, "uqn.pkl")
-    with open(pklname, "wb") as file:
-        pickle.dump(generator, file)
-
-    # save fluid Parameter
-    param_arr = []
-    for param in program.global_block().iter_parameters():
-        param_info = {
-            'name': param.name,
-            'regularizer': param.regularizer,
-            'gradient_clip_attr': param.gradient_clip_attr,
-            'trainable': param.trainable,
-            'optimize_attr': param.optimize_attr,
-            'do_model_average': param.do_model_average
-        }
-        param_arr.append(param_info)
-    pklname = os.path.join(path, "param.pkl")
-    with open(pklname, "wb") as file:
-        pickle.dump(param_arr, file)
-
-    # save signarture info
-    sign_map = module.sign2var
-    sign_arr = to_list(sign_arr)
-    for sign in sign_arr:
-        assert isinstance(
-            sign, Signature), "sign_arr should be list of Signature"
-        if sign.get_name() in sign_map:
-            raise "Error! sign_arr contains repeat signatrue %s" % sign
-
-        var = sign_map[sign.get_name()]
-        feed_desc = var.feed_desc
-        fetch_desc = var.fetch_desc
-        for input in sign.get_inputs():
-            feed_var = feed_desc.add()
-            feed_var.var_name = input.name
-        for output in sign.get_outputs():
-            fetch_var = fetch_desc.add()
-            fetch_var.var_name = output.name
-
-    # save inference program
-    exe = fluid.Executor(place=fluid.CPUPlace())
-    model_path = os.path.join(path, "model")
-    mkdir(model_path)
-    first_sign = sign_arr[0]
-    fluid.io.save_inference_model(
-        model_path,
-        feeded_var_names=[var.name for var in first_sign.get_inputs()],
-        target_vars=first_sign.get_outputs(),
-        main_program=program,
-        executor=exe)
-
-    # save to disk
-    data = module.SerializeToString()
-    metafile = os.path.join(path, "module_desc.pb")
-    with open(metafile, "wb") as f:
-        f.write(data)
paddle_hub/module_desc.proto

@@ -18,6 +18,9 @@ option optimize_for = LITE_RUNTIME;
 package paddle_hub;

+message Version {
+  int64 version = 1;
+}

 // Feed Variable Description
 message FeedDesc {
   string var_name = 1;

@@ -47,5 +50,7 @@ message ModuleDesc {
   bool return_numpy = 3;
   bool contain_assets = 4;
+
+  Version version = 5;
 };
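
Once module_desc_pb2.py is regenerated, the new field is used like any proto3 submessage. A quick sketch:

    from paddle_hub import module_desc_pb2

    desc = module_desc_pb2.ModuleDesc()
    desc.name = "bow_net"          # hypothetical module name
    desc.version.version = 1       # new Version submessage (field 5)

    data = desc.SerializeToString()
    parsed = module_desc_pb2.ModuleDesc()
    parsed.ParseFromString(data)
    assert parsed.version.version == 1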
paddle_hub/module_desc_pb2.py

@@ -17,10 +17,46 @@ DESCRIPTOR = _descriptor.FileDescriptor(
     package='paddle_hub',
     syntax='proto3',
     serialized_pb=_b(
-        '\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
+        '\n\x11module_desc.proto\x12\npaddle_hub\"\x1a\n\x07Version\x12\x0f\n\x07version\x18\x01 \x01(\x03\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xee\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12$\n\x07version\x18\x05 \x01(\x0b\x32\x13.paddle_hub.Version\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
     ))
 _sym_db.RegisterFileDescriptor(DESCRIPTOR)

+_VERSION = _descriptor.Descriptor(
+    name='Version',
+    full_name='paddle_hub.Version',
+    filename=None,
+    file=DESCRIPTOR,
+    containing_type=None,
+    fields=[
+        _descriptor.FieldDescriptor(
+            name='version',
+            full_name='paddle_hub.Version.version',
+            index=0,
+            number=1,
+            type=3,
+            cpp_type=2,
+            label=1,
+            has_default_value=False,
+            default_value=0,
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            options=None),
+    ],
+    extensions=[],
+    nested_types=[],
+    enum_types=[],
+    options=None,
+    is_extendable=False,
+    syntax='proto3',
+    extension_ranges=[],
+    oneofs=[],
+    serialized_start=33,
+    serialized_end=59,
+)
 _FEEDDESC = _descriptor.Descriptor(
     name='FeedDesc',
     full_name='paddle_hub.FeedDesc',
@@ -53,8 +89,8 @@ _FEEDDESC = _descriptor.Descriptor(
     syntax='proto3',
     extension_ranges=[],
     oneofs=[],
-    serialized_start=33,
-    serialized_end=61,
+    serialized_start=61,
+    serialized_end=89,
 )
 _FETCHDESC = _descriptor.Descriptor(

@@ -89,8 +125,8 @@ _FETCHDESC = _descriptor.Descriptor(
     syntax='proto3',
     extension_ranges=[],
     oneofs=[],
-    serialized_start=63,
-    serialized_end=92,
+    serialized_start=91,
+    serialized_end=120,
 )
 _MODULEVAR = _descriptor.Descriptor(

@@ -141,8 +177,8 @@ _MODULEVAR = _descriptor.Descriptor(
     syntax='proto3',
     extension_ranges=[],
     oneofs=[],
-    serialized_start=94,
-    serialized_end=189,
+    serialized_start=122,
+    serialized_end=217,
 )
 _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(

@@ -194,8 +230,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
     syntax='proto3',
     extension_ranges=[],
     oneofs=[],
-    serialized_start=322,
-    serialized_end=392,
+    serialized_start=388,
+    serialized_end=458,
 )
 _MODULEDESC = _descriptor.Descriptor(

@@ -269,6 +305,22 @@ _MODULEDESC = _descriptor.Descriptor(
             is_extension=False,
             extension_scope=None,
             options=None),
+        _descriptor.FieldDescriptor(
+            name='version',
+            full_name='paddle_hub.ModuleDesc.version',
+            index=4,
+            number=5,
+            type=11,
+            cpp_type=10,
+            label=1,
+            has_default_value=False,
+            default_value=None,
+            message_type=None,
+            enum_type=None,
+            containing_type=None,
+            is_extension=False,
+            extension_scope=None,
+            options=None),
     ],
     extensions=[],
     nested_types=[

@@ -280,8 +332,8 @@ _MODULEDESC = _descriptor.Descriptor(
     syntax='proto3',
     extension_ranges=[],
     oneofs=[],
-    serialized_start=192,
-    serialized_end=392,
+    serialized_start=220,
+    serialized_end=458,
 )
 _MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC

@@ -289,11 +341,23 @@ _MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC
 _MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR
 _MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC
 _MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY
+_MODULEDESC.fields_by_name['version'].message_type = _VERSION
+DESCRIPTOR.message_types_by_name['Version'] = _VERSION
 DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC
 DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC
 DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR
 DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC

+Version = _reflection.GeneratedProtocolMessageType(
+    'Version',
+    (_message.Message, ),
+    dict(
+        DESCRIPTOR=_VERSION,
+        __module__='module_desc_pb2'
+        # @@protoc_insertion_point(class_scope:paddle_hub.Version)
+    ))
+_sym_db.RegisterMessage(Version)
+
 FeedDesc = _reflection.GeneratedProtocolMessageType(
     'FeedDesc',
     (_message.Message, ),