提交 715fde80 编写于 作者: Z Zeyu Chen

Merge module_creator.py to module.py

上级 b9f9ff25
...@@ -160,18 +160,19 @@ def train_net(train_reader, ...@@ -160,18 +160,19 @@ def train_net(train_reader,
hub.create_module( hub.create_module(
sign_arr=signature, sign_arr=signature,
program=fluid.default_main_program(), program=fluid.default_main_program(),
path=module_dir) module_dir=module_dir,
word_dict=word_dict)
def retrain_net(train_reader,
word_dict, def finetune_net(train_reader,
network_name, word_dict,
use_gpu, network_name,
parallel, use_gpu,
save_dirname, parallel,
lr=0.002, save_dirname,
batch_size=128, lr=0.002,
pass_num=30): batch_size=128,
pass_num=30):
""" """
train network train network
""" """
...@@ -198,73 +199,71 @@ def retrain_net(train_reader, ...@@ -198,73 +199,71 @@ def retrain_net(train_reader,
module_dir = os.path.join(save_dirname, network_name) module_dir = os.path.join(save_dirname, network_name)
module = hub.Module(module_dir=module_dir) module = hub.Module(module_dir=module_dir)
main_program = fluid.Program() feed_list, fetch_list, program = module(sign_name="default", trainable=True)
startup_program = fluid.Program() with fluid.program_guard(main_program=program):
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
# use switch program to test fine-tuning # data = module.get_feed_var_by_index(0)
fluid.framework.switch_main_program(module.get_inference_program()) #TODO(ZeyuChen): how to get output paramter according to proto config
sent_emb = fetch_list[0]
label = fluid.layers.data(name="label", shape=[1], dtype="int64") # sent_emb = module.get_fetch_var_by_index(0)
data = module.get_feed_var_by_index(0)
#TODO(ZeyuChen): how to get output paramter according to proto config fc_1 = fluid.layers.fc(
sent_emb = module.get_fetch_var_by_index(0) input=sent_emb, size=hid_dim, act="tanh", name="bow_fc1")
fc_2 = fluid.layers.fc(
fc_1 = fluid.layers.fc( input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
input=sent_emb, size=hid_dim, act="tanh", name="bow_fc1")
fc_2 = fluid.layers.fc( # softmax layer
input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2") pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
# print(fluid.default_main_program())
# softmax layer cost = fluid.layers.mean(
pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") fluid.layers.cross_entropy(input=pred, label=label))
# print(fluid.default_main_program()) acc = fluid.layers.accuracy(input=pred, label=label)
cost = fluid.layers.mean(
fluid.layers.cross_entropy(input=pred, label=label)) with open("./prototxt/bow_net.forward.program_desc.prototxt",
acc = fluid.layers.accuracy(input=pred, label=label) "w") as fo:
program_desc = str(fluid.default_main_program())
with open("./prototxt/bow_net.forward.program_desc.prototxt", "w") as fo: fo.write(program_desc)
program_desc = str(fluid.default_main_program()) # set optimizer
fo.write(program_desc) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
# set optimizer sgd_optimizer.minimize(cost)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost) with open("./prototxt/bow_net.finetune.program_desc.prototxt",
"w") as fo:
with open("./prototxt/bow_net.finetune.program_desc.prototxt", "w") as fo: program_desc = str(fluid.default_main_program())
program_desc = str(fluid.default_main_program()) fo.write(program_desc)
fo.write(program_desc)
# set place, executor, datafeeder
# set place, executor, datafeeder place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() exe = fluid.Executor(place)
exe = fluid.Executor(place) feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place) exe.run(fluid.default_startup_program())
exe.run(fluid.default_startup_program()) # start training...
# start training...
for pass_id in range(pass_num):
for pass_id in range(pass_num): data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0 for batch in train_reader():
for batch in train_reader(): avg_cost_np, avg_acc_np = exe.run(
avg_cost_np, avg_acc_np = exe.run( fluid.default_main_program(),
fluid.default_main_program(), feed=feeder.feed(batch),
feed=feeder.feed(batch), fetch_list=[cost, acc],
fetch_list=[cost, acc], return_numpy=True)
return_numpy=True) data_size = len(batch)
data_size = len(batch) total_acc += data_size * avg_acc_np
total_acc += data_size * avg_acc_np total_cost += data_size * avg_cost_np
total_cost += data_size * avg_cost_np data_count += data_size
data_count += data_size avg_cost = total_cost / data_count
avg_cost = total_cost / data_count avg_acc = total_acc / data_count
avg_acc = total_acc / data_count print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" % (pass_id, avg_acc, avg_cost))
(pass_id, avg_acc, avg_cost))
# # save the model
# save the model # module_dir = os.path.join(save_dirname, network_name)
# signature = hub.create_signature(
module_dir = os.path.join(save_dirname, network_name) # "default", inputs=[data], outputs=[sent_emb])
signature = hub.create_signature( # hub.create_module(
"default", inputs=[data], outputs=[sent_emb]) # sign_arr=signature,
hub.create_module( # program=fluid.default_main_program(),
sign_arr=signature, # path=module_dir)
program=fluid.default_main_program(),
path=module_dir)
def eval_net(test_reader, use_gpu, model_path=None): def eval_net(test_reader, use_gpu, model_path=None):
...@@ -367,9 +366,9 @@ def main(args): ...@@ -367,9 +366,9 @@ def main(args):
args.word_dict_path, args.word_dict_path,
args.batch_size, args.mode) args.batch_size, args.mode)
retrain_net(train_reader, word_dict, args.model_type, args.use_gpu, finetune_net(train_reader, word_dict, args.model_type, args.use_gpu,
args.is_parallel, args.model_path, args.lr, args.batch_size, args.is_parallel, args.model_path, args.lr,
args.num_passes) args.batch_size, args.num_passes)
# eval mode # eval mode
elif args.mode == "eval": elif args.mode == "eval":
# prepare_data to get word_dict, test_reader # prepare_data to get word_dict, test_reader
......
python sentiment_classify.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models python sentiment_classify.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models --num_passes=1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -7,7 +21,6 @@ import paddle.fluid as fluid ...@@ -7,7 +21,6 @@ import paddle.fluid as fluid
from paddle_hub.module import Module from paddle_hub.module import Module
from paddle_hub.module import ModuleConfig from paddle_hub.module import ModuleConfig
from paddle_hub.module import ModuleUtils from paddle_hub.module import ModuleUtils
from paddle_hub.module import create_module
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
from paddle_hub.signature import create_signature from paddle_hub.signature import create_signature
from paddle_hub.module_creator import create_module
from paddle_hub.config import RunConfig, ParamTrainConfig
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from enum import Enum, unique
@unique
class ParamTrainConfig(Enum):
PARAM_TRAIN_DEFAULT = 0
PARAM_TRAIN_ALL = 1
PARAM_TRAIN_NONE = 2
class RunConfig:
def __init__(self, param_train_config=None):
assert (not param_train_config or param_train_config in ParamTrainConfig
), "train config should be value of %s" % ParamTrainConfig
if not param_train_config:
param_train_config = ParamTrainConfig.PARAM_TRAIN_DEFAULT
self.param_train_config = param_train_config
...@@ -27,11 +27,20 @@ import pickle ...@@ -27,11 +27,20 @@ import pickle
from collections import defaultdict from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2 from paddle_hub import module_desc_pb2
from paddle_hub.config import RunConfig, ParamTrainConfig from paddle_hub.signature import Signature
from paddle_hub.utils import to_list
__all__ = ["Module", "ModuleConfig", "ModuleUtils"] __all__ = ["Module", "ModuleConfig", "ModuleUtils"]
DICT_NAME = "dict.txt"
ASSETS_NAME = "assets" # paddle hub module dir name
ASSETS_DIRNAME = "assets"
META_DIRNAME = "meta"
MODEL_DIRNAME = "model"
# paddle hub module serialze file name
DICT_FILENAME = "vocab.txt"
PARAM_FILENAME = "param.pkl"
MODULE_DESC_PBNAME = "module_desc.pb"
GENERATOR_FILENAME = "unique_name_generator.pkl"
def mkdir(path): def mkdir(path):
...@@ -67,8 +76,7 @@ class Module(object): ...@@ -67,8 +76,7 @@ class Module(object):
# load paddle inference model # load paddle inference model
place = fluid.CPUPlace() place = fluid.CPUPlace()
model_dir = os.path.join(self.module_dir, "model") model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
print("model_dir", model_dir)
self.exe = fluid.Executor(fluid.CPUPlace()) self.exe = fluid.Executor(fluid.CPUPlace())
[self.inference_program, self.feed_target_names, [self.inference_program, self.feed_target_names,
self.fetch_targets] = fluid.io.load_inference_model( self.fetch_targets] = fluid.io.load_inference_model(
...@@ -91,14 +99,15 @@ class Module(object): ...@@ -91,14 +99,15 @@ class Module(object):
self._process_uqn() self._process_uqn()
def _process_uqn(self): def _process_uqn(self):
filepath = os.path.join(self.module_dir, "uqn.pkl") name_generator_path = ModuleConfig.name_generator_path(self.module_dir)
with open(filepath, "rb") as file: with open(name_generator_path, "rb") as fi:
fluid.unique_name.switch(pickle.load(file)) fluid.unique_name.switch(pickle.load(fi))
def _process_parameter(self): def _process_parameter(self):
global_block = self.inference_program.global_block() global_block = self.inference_program.global_block()
filepath = os.path.join(self.module_dir, "param.pkl") filepath = os.path.join(self.module_dir, "param.pkl")
with open(filepath, "rb") as file: param_path = ModuleConfig.meta_param_path(self.module_dir)
with open(param_path, "rb") as file:
param_arr = pickle.load(file) param_arr = pickle.load(file)
for param in param_arr: for param in param_arr:
if (param['name'] not in global_block.vars): if (param['name'] not in global_block.vars):
...@@ -124,7 +133,7 @@ class Module(object): ...@@ -124,7 +133,7 @@ class Module(object):
return feed_dict return feed_dict
def __call__(self, sign_name="default", run_config=None): def __call__(self, sign_name="default", trainable=False):
""" Call default signature and return results """ Call default signature and return results
""" """
...@@ -137,16 +146,10 @@ class Module(object): ...@@ -137,16 +146,10 @@ class Module(object):
if op.has_attr("is_test"): if op.has_attr("is_test"):
op._set_attr("is_test", is_test) op._set_attr("is_test", is_test)
if not run_config:
run_config = RunConfig()
program = self.get_inference_program().clone() program = self.get_inference_program().clone()
_process_op_attr(program=program, is_test=False) _process_op_attr(program=program, is_test=False)
if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL: _set_param_trainable(program=program, trainable=trainable)
_set_param_trainable(program=program, trainable=True)
elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
_set_param_trainable(program=program, trainable=False)
return self.feed_target_names, self.fetch_targets, program return self.feed_target_names, self.fetch_targets, program
...@@ -282,79 +285,30 @@ class ModuleConfig(object): ...@@ -282,79 +285,30 @@ class ModuleConfig(object):
Load module config from module directory. Load module config from module directory.
""" """
#TODO(ZeyuChen): check module_desc.pb exsitance #TODO(ZeyuChen): check module_desc.pb exsitance
pb_path = os.path.join(self.module_dir, "module_desc.pb") with open(ModuleConfig.module_desc_path(self.module_dir), "rb") as fi:
with open(pb_path, "rb") as fi:
self.desc.ParseFromString(fi.read()) self.desc.ParseFromString(fi.read())
# print("self.desc.sign2var",
# self.desc.sign2var["default"].feed_desc[0].var_name)
if self.desc.contain_assets: if self.desc.contain_assets:
# load assets # load assets
assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
dict_path = os.path.join(assets_dir, DICT_NAME)
word_id = 0 word_id = 0
with open(ModuleConfig.assets_dict_path(self.module_dir)) as fi:
with open(dict_path) as fi:
words = fi.readlines() words = fi.readlines()
#TODO(ZeyuChen) check whether word id is duplicated and valid #TODO(ZeyuChen) check whether word id is duplicated and valid
for line in fi: for line in fi:
w, w_id = line.split() w, w_id = line.split()
self.dict[w] = int(w_id) self.dict[w] = int(w_id)
def dump(self):
""" Save Module configure file to disk.
"""
pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "wb") as fo:
fo.write(self.desc.SerializeToString())
# save assets/dictionary
assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
mkdir(assets_dir)
with open(os.path.join(assets_dir, DICT_NAME), "w") as fo:
for w in self.dict:
w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id))
def return_numpy(self): def return_numpy(self):
"""Return numpy or not according to the proto config. """Return numpy or not according to the proto config.
""" """
return self.desc.return_numpy return self.desc.return_numpy
def save_dict(self, word_dict, dict_name=DICT_NAME): def save_dict(self, word_dict, dict_name=DICT_FILENAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
for w in word_dict: for w in word_dict:
self.dict[w] = word_dict[w] self.dict[w] = word_dict[w]
def register_feed_signature(self, feed_desc, sign_name="default"):
""" Register feed signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for k in feed_desc:
feed = self.desc.sign2var[sign_name].feed_desc.add()
feed.key = k
feed.var_name = feed_desc[k]
def register_fetch_signature(self, fetch_desc, sign_name="default"):
""" Register fetch signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for k in fetch_desc:
fetch = self.desc.sign2var[sign_name].fetch_desc.add()
fetch.key = k
fetch.var_name = fetch_desc[k]
def feed_var_names(self, sign_name="default"): def feed_var_names(self, sign_name="default"):
return self.desc.sign2var[sign_name].feed_desc return self.desc.sign2var[sign_name].feed_desc
...@@ -377,6 +331,119 @@ class ModuleConfig(object): ...@@ -377,6 +331,119 @@ class ModuleConfig(object):
return desc.var_name return desc.var_name
raise Exception("fetch variable {} not found".format(key)) raise Exception("fetch variable {} not found".format(key))
@staticmethod
def module_desc_path(module_dir):
return os.path.join(module_dir, MODULE_DESC_PBNAME)
@staticmethod
def name_generator_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, GENERATOR_FILENAME)
@staticmethod
def assets_dict_path(module_dir):
assets_path = os.path.join(module_dir, ASSETS_DIRNAME)
mkdir(assets_path)
return os.path.join(assets_path, DICT_FILENAME)
@staticmethod
def meta_param_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, PARAM_FILENAME)
@staticmethod
def meta_name_generator_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, GENERATOR_FILENAME)
def create_module(sign_arr, program, module_dir=None, word_dict=None):
""" Create a module from main program
"""
assert isinstance(
program, fluid.Program), "program should be instance of fluid.Program"
assert sign_arr, "signature array should not be None"
if module_dir is None:
module_dir = os.path.join(".", "hub_module")
# create module path for saving
mkdir(module_dir)
module = module_desc_pb2.ModuleDesc()
program = program.clone()
if word_dict is None:
module.contain_assets = False
else:
module.contain_assets = True
with open(ModuleConfig.assets_dict_path(module_dir), "w") as fo:
for w in word_dict:
w_id = word_dict[w]
fo.write("{}\t{}\n".format(w, w_id))
# save the unique name generator object
generator = fluid.unique_name.generator
with open(ModuleConfig.name_generator_path(module_dir), "wb") as fo:
pickle.dump(generator, fo)
# save fluid Parameter
param_arr = []
for param in program.global_block().iter_parameters():
param_info = {
'name': param.name,
'regularizer': param.regularizer,
'gradient_clip_attr': param.gradient_clip_attr,
'trainable': param.trainable,
'optimize_attr': param.optimize_attr,
'do_model_average': param.do_model_average
}
param_arr.append(param_info)
with open(ModuleConfig.meta_param_path(module_dir), "wb") as fo:
pickle.dump(param_arr, fo)
# save signarture info
sign_map = module.sign2var
sign_arr = to_list(sign_arr)
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
if sign.get_name() in sign_map:
raise "Error! sign_arr contains repeat signatrue %s" % sign
var = sign_map[sign.get_name()]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
for input in sign.get_inputs():
feed_var = feed_desc.add()
feed_var.var_name = input.name
for output in sign.get_outputs():
fetch_var = fetch_desc.add()
fetch_var.var_name = output.name
# save inference program
exe = fluid.Executor(place=fluid.CPUPlace())
model_dir = os.path.join(module_dir, "model")
mkdir(model_dir)
# TODO(ZeyuChen): here only deal with one signature
first_sign = sign_arr[0]
fluid.io.save_inference_model(
model_dir,
feeded_var_names=[var.name for var in first_sign.get_inputs()],
target_vars=first_sign.get_outputs(),
main_program=program,
executor=exe)
# save to disk
data = module.SerializeToString()
with open(ModuleConfig.module_desc_path(module_dir), "wb") as f:
f.write(data)
class ModuleUtils(object): class ModuleUtils(object):
def __init__(self): def __init__(self):
...@@ -400,6 +467,7 @@ class ModuleUtils(object): ...@@ -400,6 +467,7 @@ class ModuleUtils(object):
block._remove_var("fetch") block._remove_var("fetch")
program.desc.flush() program.desc.flush()
# print("********************************")
# print(program) @staticmethod
# print("********************************") def module_desc_path(module_dir):
pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle_hub.module_desc_pb2 as modulepb
import paddle.fluid as fluid
from paddle_hub.utils import to_list
from paddle_hub.signature import Signature
from paddle_hub.module import mkdir
import os
import pickle
def create_module(sign_arr, program, path=None, assets=None):
assert isinstance(
program, fluid.Program), "program should be instance of fluid.Program"
assert sign_arr, "signarture array should not be None"
if not path:
path = os.path.join(".", "hub_module")
# create module path for saving
mkdir(path)
module = modulepb.ModuleDesc()
program = program.clone()
# TODO(wuzewu): save assets data
if not assets:
module.contain_assets = False
else:
module.contain_assets = True
os.makedirs(os.path.join(path, "assets"))
# save the unique name object
generator = fluid.unique_name.generator
pklname = os.path.join(path, "uqn.pkl")
with open(pklname, "wb") as file:
pickle.dump(generator, file)
# save fluid Parameter
param_arr = []
for param in program.global_block().iter_parameters():
param_info = {
'name': param.name,
'regularizer': param.regularizer,
'gradient_clip_attr': param.gradient_clip_attr,
'trainable': param.trainable,
'optimize_attr': param.optimize_attr,
'do_model_average': param.do_model_average
}
param_arr.append(param_info)
pklname = os.path.join(path, "param.pkl")
with open(pklname, "wb") as file:
pickle.dump(param_arr, file)
# save signarture info
sign_map = module.sign2var
sign_arr = to_list(sign_arr)
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
if sign.get_name() in sign_map:
raise "Error! sign_arr contains repeat signatrue %s" % sign
var = sign_map[sign.get_name()]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
for input in sign.get_inputs():
feed_var = feed_desc.add()
feed_var.var_name = input.name
for output in sign.get_outputs():
fetch_var = fetch_desc.add()
fetch_var.var_name = output.name
# save inference program
exe = fluid.Executor(place=fluid.CPUPlace())
model_path = os.path.join(path, "model")
mkdir(model_path)
first_sign = sign_arr[0]
fluid.io.save_inference_model(
model_path,
feeded_var_names=[var.name for var in first_sign.get_inputs()],
target_vars=first_sign.get_outputs(),
main_program=program,
executor=exe)
# save to disk
data = module.SerializeToString()
metafile = os.path.join(path, "module_desc.pb")
with open(metafile, "wb") as f:
f.write(data)
...@@ -18,6 +18,9 @@ option optimize_for = LITE_RUNTIME; ...@@ -18,6 +18,9 @@ option optimize_for = LITE_RUNTIME;
package paddle_hub; package paddle_hub;
message Version {
int64 version = 1;
}
// Feed Variable Description // Feed Variable Description
message FeedDesc { message FeedDesc {
string var_name = 1; string var_name = 1;
...@@ -47,5 +50,7 @@ message ModuleDesc { ...@@ -47,5 +50,7 @@ message ModuleDesc {
bool return_numpy = 3; bool return_numpy = 3;
bool contain_assets = 4; bool contain_assets = 4;
Version version = 5;
}; };
...@@ -17,10 +17,46 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -17,10 +17,46 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub', package='paddle_hub',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3' '\n\x11module_desc.proto\x12\npaddle_hub\"\x1a\n\x07Version\x12\x0f\n\x07version\x18\x01 \x01(\x03\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xee\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12$\n\x07version\x18\x05 \x01(\x0b\x32\x13.paddle_hub.Version\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle_hub.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version',
full_name='paddle_hub.Version.version',
index=0,
number=1,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=33,
serialized_end=59,
)
_FEEDDESC = _descriptor.Descriptor( _FEEDDESC = _descriptor.Descriptor(
name='FeedDesc', name='FeedDesc',
full_name='paddle_hub.FeedDesc', full_name='paddle_hub.FeedDesc',
...@@ -53,8 +89,8 @@ _FEEDDESC = _descriptor.Descriptor( ...@@ -53,8 +89,8 @@ _FEEDDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=33, serialized_start=61,
serialized_end=61, serialized_end=89,
) )
_FETCHDESC = _descriptor.Descriptor( _FETCHDESC = _descriptor.Descriptor(
...@@ -89,8 +125,8 @@ _FETCHDESC = _descriptor.Descriptor( ...@@ -89,8 +125,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=63, serialized_start=91,
serialized_end=92, serialized_end=120,
) )
_MODULEVAR = _descriptor.Descriptor( _MODULEVAR = _descriptor.Descriptor(
...@@ -141,8 +177,8 @@ _MODULEVAR = _descriptor.Descriptor( ...@@ -141,8 +177,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=94, serialized_start=122,
serialized_end=189, serialized_end=217,
) )
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
...@@ -194,8 +230,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( ...@@ -194,8 +230,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=322, serialized_start=388,
serialized_end=392, serialized_end=458,
) )
_MODULEDESC = _descriptor.Descriptor( _MODULEDESC = _descriptor.Descriptor(
...@@ -269,6 +305,22 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -269,6 +305,22 @@ _MODULEDESC = _descriptor.Descriptor(
is_extension=False, is_extension=False,
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor(
name='version',
full_name='paddle_hub.ModuleDesc.version',
index=4,
number=5,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
], ],
extensions=[], extensions=[],
nested_types=[ nested_types=[
...@@ -280,8 +332,8 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -280,8 +332,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=192, serialized_start=220,
serialized_end=392, serialized_end=458,
) )
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC _MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
...@@ -289,11 +341,23 @@ _MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC ...@@ -289,11 +341,23 @@ _MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC
_MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR _MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR
_MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC _MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC
_MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY _MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY
_MODULEDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC
DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC
DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR
DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC
Version = _reflection.GeneratedProtocolMessageType(
'Version',
(_message.Message, ),
dict(
DESCRIPTOR=_VERSION,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.Version)
))
_sym_db.RegisterMessage(Version)
FeedDesc = _reflection.GeneratedProtocolMessageType( FeedDesc = _reflection.GeneratedProtocolMessageType(
'FeedDesc', 'FeedDesc',
(_message.Message, ), (_message.Message, ),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册