提交 62fd63c9 编写于 作者: Z Zeyu Chen

add module_desc.proto and fix senta example usage issues

上级 5373f2d9
......@@ -152,10 +152,12 @@ def train_net(train_reader,
(pass_id, avg_acc, avg_cost))
# save the model
module_path = os.path.join(save_dirname, network_name)
hub.ModuleDesc.save_module_dict(
module_path=module_path, word_dict=word_dict)
fluid.io.save_inference_model(module_path, ["words"], emb, exe)
module_dir = os.path.join(save_dirname, network_name)
fluid.io.save_inference_model(module_dir, ["words"], emb, exe)
config = hub.ModuleConfig(module_dir)
config.save_dict(word_dict=word_dict)
config.dump()
def retrain_net(train_reader,
......@@ -209,10 +211,7 @@ def retrain_net(train_reader,
#TODO(ZeyuChen): how to get output paramter according to proto config
emb = module.get_module_output()
print(
"adfjkajdlfjoqi jqiorejlmsfdlkjoi jqwierjoajsdklfjoi qjerijoajdfiqwjeor adfkalsf"
)
# # # embedding layer
# # embedding layer
# emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
# # bow layer
......@@ -264,8 +263,8 @@ def retrain_net(train_reader,
# print("senta_load_module", fluid.default_main_program())
# save the model
module_path = os.path.join(save_dirname, network_name + "_retrain")
fluid.io.save_inference_model(module_path, ["words"], emb, exe)
module_dir = os.path.join(save_dirname, network_name + "_retrain")
fluid.io.save_inference_model(module_dir, ["words"], emb, exe)
def eval_net(test_reader, use_gpu, model_path=None):
......
......@@ -6,4 +6,5 @@ import paddle.fluid as fluid
from paddle_hub.module import Module
from paddle_hub.module import ModuleConfig
from paddle_hub.module import ModuleUtils
from paddle_hub.downloader import download_and_uncompress
......@@ -46,15 +46,16 @@ class Module(object):
module_url)
else:
# otherwise it's local path, no need to deal with it
print("Module.__init__", module_url)
self.module_dir = module_url
self.module_name = module_url.split()[-1]
self.module_name = module_url.split("/")[-1]
# load paddle inference model
place = fluid.CPUPlace()
self.exe = fluid.Executor(fluid.CPUPlace())
[self.inference_program, self.feed_target_names,
self.fetch_targets] = fluid.io.load_inference_model(
dirname=module_dir, executor=self.exe)
dirname=self.module_dir, executor=self.exe)
print("inference_program")
print(self.inference_program)
......@@ -63,8 +64,8 @@ class Module(object):
print("fetch_targets")
print(self.fetch_targets)
config = ModuleConfig()
config.load(self.module_dir)
config = ModuleConfig(self.module_dir)
config.load()
# load assets
# self.dict = defaultdict(int)
# self.dict.setdefault(0)
......@@ -188,11 +189,12 @@ class ModuleConfig(object):
self.dict = defaultdict(int)
self.dict.setdefault(0)
def load(self, module_dir):
def load(self):
"""load module config from module dir
"""
#TODO(ZeyuChen): check module_desc.pb existence
with open(pb_file_path, "rb") as fi:
pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "rb") as fi:
self.desc.ParseFromString(fi.read())
if self.desc.contain_assets:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册