提交 a8a0a8f1 编写于 作者: Z Zeyu Chen

add paddlehub version string, fix test case

上级 12e3dd4d
......@@ -199,8 +199,10 @@ def finetune_net(train_reader,
module_dir = os.path.join(save_dirname, network_name)
module = hub.Module(module_dir=module_dir)
feed_list, fetch_list, program = module(sign_name="default", trainable=True)
feed_list, fetch_list, program, generator = module(
sign_name="default", trainable=True)
with fluid.program_guard(main_program=program):
with fluid.unique_name.guard(generator):
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
# data = module.get_feed_var_by_index(0)
#TODO(ZeyuChen): how to get output paramter according to proto config
......@@ -256,14 +258,10 @@ def finetune_net(train_reader,
print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
(pass_id, avg_acc, avg_cost))
# # save the model
# module_dir = os.path.join(save_dirname, network_name)
# signature = hub.create_signature(
# "default", inputs=[data], outputs=[sent_emb])
# hub.create_module(
# sign_arr=signature,
# program=fluid.default_main_program(),
# path=module_dir)
# save the model
model_dir = os.path.join(save_dirname, network_name + "_finetune")
fluid.io.save_persistables(
executor=exe, dirname=model_dir, main_program=None)
def eval_net(test_reader, use_gpu, model_path=None):
......
......@@ -24,3 +24,4 @@ from paddle_hub.module import ModuleUtils
from paddle_hub.module import create_module
from paddle_hub.downloader import download_and_uncompress
from paddle_hub.signature import create_signature
from paddle_hub.version import __version__
......@@ -56,7 +56,6 @@ def md5file(fname):
def download_and_uncompress(url, save_name=None):
module_name = url.split("/")[-2]
dirname = os.path.join(MODULE_HOME, module_name)
print("download to dir", dirname)
if not os.path.exists(dirname):
os.makedirs(dirname)
......@@ -115,6 +114,3 @@ if __name__ == "__main__":
module_path = download_and_uncompress(link)
print("module path", module_path)
# dl = DownloadManager()
# dl.download_and_uncompress(link, "./tmp")
......@@ -29,6 +29,7 @@ from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.signature import Signature
from paddle_hub.utils import to_list
from paddle_hub.version import __version__
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
......@@ -259,13 +260,14 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
# create module path for saving
mkdir(module_dir)
module = module_desc_pb2.ModuleDesc()
module_desc = module_desc_pb2.ModuleDesc()
module_desc.version = __version__
program = program.clone()
if word_dict is None:
module.contain_assets = False
module_desc.contain_assets = False
else:
module.contain_assets = True
module_desc.contain_assets = True
with open(ModuleConfig.assets_dict_path(module_dir), "w") as fo:
for w in word_dict:
w_id = word_dict[w]
......@@ -301,7 +303,7 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
pickle.dump(param_arr, fo)
# save signarture info
sign_map = module.sign2var
sign_map = module_desc.sign2var
sign_arr = to_list(sign_arr)
for sign in sign_arr:
assert isinstance(sign,
......@@ -335,10 +337,10 @@ def create_module(sign_arr, program, module_dir=None, word_dict=None):
main_program=program,
executor=exe)
# save to disk
data = module.SerializeToString()
# Serialize module_desc pb
module_pb = module_desc.SerializeToString()
with open(ModuleConfig.module_desc_path(module_dir), "wb") as f:
f.write(data)
f.write(module_pb)
class ModuleUtils(object):
......@@ -363,7 +365,3 @@ class ModuleUtils(object):
block._remove_var("fetch")
program.desc.flush()
@staticmethod
def module_desc_path(module_dir):
pass
......@@ -18,9 +18,6 @@ option optimize_for = LITE_RUNTIME;
package paddle_hub;
message Version {
int64 version = 1;
}
// Feed Variable Description
message FeedDesc {
string var_name = 1;
......@@ -51,6 +48,6 @@ message ModuleDesc {
bool contain_assets = 4;
Version version = 5;
string version = 5;
};
......@@ -17,46 +17,10 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub',
syntax='proto3',
serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\x1a\n\x07Version\x12\x0f\n\x07version\x18\x01 \x01(\x03\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xee\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12$\n\x07version\x18\x05 \x01(\x0b\x32\x13.paddle_hub.Version\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
'\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xd9\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\x0f\n\x07version\x18\x05 \x01(\t\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle_hub.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version',
full_name='paddle_hub.Version.version',
index=0,
number=1,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=33,
serialized_end=59,
)
_FEEDDESC = _descriptor.Descriptor(
name='FeedDesc',
full_name='paddle_hub.FeedDesc',
......@@ -89,8 +53,8 @@ _FEEDDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=61,
serialized_end=89,
serialized_start=33,
serialized_end=61,
)
_FETCHDESC = _descriptor.Descriptor(
......@@ -125,8 +89,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=91,
serialized_end=120,
serialized_start=63,
serialized_end=92,
)
_MODULEVAR = _descriptor.Descriptor(
......@@ -177,8 +141,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=122,
serialized_end=217,
serialized_start=94,
serialized_end=189,
)
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
......@@ -230,8 +194,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=388,
serialized_end=458,
serialized_start=339,
serialized_end=409,
)
_MODULEDESC = _descriptor.Descriptor(
......@@ -310,11 +274,11 @@ _MODULEDESC = _descriptor.Descriptor(
full_name='paddle_hub.ModuleDesc.version',
index=4,
number=5,
type=11,
cpp_type=10,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=None,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
......@@ -332,8 +296,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=220,
serialized_end=458,
serialized_start=192,
serialized_end=409,
)
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
......@@ -341,23 +305,11 @@ _MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC
_MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR
_MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC
_MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY
_MODULEDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC
DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC
DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR
DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC
Version = _reflection.GeneratedProtocolMessageType(
'Version',
(_message.Message, ),
dict(
DESCRIPTOR=_VERSION,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.Version)
))
_sym_db.RegisterMessage(Version)
FeedDesc = _reflection.GeneratedProtocolMessageType(
'FeedDesc',
(_message.Message, ),
......
......@@ -165,11 +165,11 @@ def test_create_w2v_module(use_gpu=False):
def test_load_w2v_module(use_gpu=False):
saved_module_dir = "./tmp/word2vec_test_module"
w2v_module = hub.Module(module_dir=saved_module_dir)
feed_list, fetch_list, program = w2v_module(
feed_list, fetch_list, program, generator = w2v_module(
sign_name="default", trainable=False)
with fluid.program_guard(main_program=program):
with fluid.unique_name.guard(generator):
pred_prob = fetch_list[0]
pred_word = fluid.layers.argmax(x=pred_prob, axis=1)
# set place, executor, datafeeder
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册