提交 62fd63c9 编写于 作者: Z Zeyu Chen

add module_desc.proto and fix senta example usage issues

上级 5373f2d9
...@@ -152,10 +152,12 @@ def train_net(train_reader, ...@@ -152,10 +152,12 @@ def train_net(train_reader,
(pass_id, avg_acc, avg_cost)) (pass_id, avg_acc, avg_cost))
# save the model # save the model
module_path = os.path.join(save_dirname, network_name) module_dir = os.path.join(save_dirname, network_name)
hub.ModuleDesc.save_module_dict( fluid.io.save_inference_model(module_dir, ["words"], emb, exe)
module_path=module_path, word_dict=word_dict)
fluid.io.save_inference_model(module_path, ["words"], emb, exe) config = hub.ModuleConfig(module_dir)
config.save_dict(word_dict=word_dict)
config.dump()
def retrain_net(train_reader, def retrain_net(train_reader,
...@@ -209,10 +211,7 @@ def retrain_net(train_reader, ...@@ -209,10 +211,7 @@ def retrain_net(train_reader,
#TODO(ZeyuChen): how to get output paramter according to proto config #TODO(ZeyuChen): how to get output paramter according to proto config
emb = module.get_module_output() emb = module.get_module_output()
print( # # embedding layer
"adfjkajdlfjoqi jqiorejlmsfdlkjoi jqwierjoajsdklfjoi qjerijoajdfiqwjeor adfkalsf"
)
# # # embedding layer
# emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) # emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding") # #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
# # bow layer # # bow layer
...@@ -264,8 +263,8 @@ def retrain_net(train_reader, ...@@ -264,8 +263,8 @@ def retrain_net(train_reader,
# print("senta_load_module", fluid.default_main_program()) # print("senta_load_module", fluid.default_main_program())
# save the model # save the model
module_path = os.path.join(save_dirname, network_name + "_retrain") module_dir = os.path.join(save_dirname, network_name + "_retrain")
fluid.io.save_inference_model(module_path, ["words"], emb, exe) fluid.io.save_inference_model(module_dir, ["words"], emb, exe)
def eval_net(test_reader, use_gpu, model_path=None): def eval_net(test_reader, use_gpu, model_path=None):
......
...@@ -6,4 +6,5 @@ import paddle.fluid as fluid ...@@ -6,4 +6,5 @@ import paddle.fluid as fluid
from paddle_hub.module import Module from paddle_hub.module import Module
from paddle_hub.module import ModuleConfig from paddle_hub.module import ModuleConfig
from paddle_hub.module import ModuleUtils
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
...@@ -46,15 +46,16 @@ class Module(object): ...@@ -46,15 +46,16 @@ class Module(object):
module_url) module_url)
else: else:
# otherwise it's local path, no need to deal with it # otherwise it's local path, no need to deal with it
print("Module.__init__", module_url)
self.module_dir = module_url self.module_dir = module_url
self.module_name = module_url.split()[-1] self.module_name = module_url.split("/")[-1]
# load paddle inference model # load paddle inference model
place = fluid.CPUPlace() place = fluid.CPUPlace()
self.exe = fluid.Executor(fluid.CPUPlace()) self.exe = fluid.Executor(fluid.CPUPlace())
[self.inference_program, self.feed_target_names, [self.inference_program, self.feed_target_names,
self.fetch_targets] = fluid.io.load_inference_model( self.fetch_targets] = fluid.io.load_inference_model(
dirname=module_dir, executor=self.exe) dirname=self.module_dir, executor=self.exe)
print("inference_program") print("inference_program")
print(self.inference_program) print(self.inference_program)
...@@ -63,8 +64,8 @@ class Module(object): ...@@ -63,8 +64,8 @@ class Module(object):
print("fetch_targets") print("fetch_targets")
print(self.fetch_targets) print(self.fetch_targets)
config = ModuleConfig() config = ModuleConfig(self.module_dir)
config.load(self.module_dir) config.load()
# load assets # load assets
# self.dict = defaultdict(int) # self.dict = defaultdict(int)
# self.dict.setdefault(0) # self.dict.setdefault(0)
...@@ -188,11 +189,12 @@ class ModuleConfig(object): ...@@ -188,11 +189,12 @@ class ModuleConfig(object):
self.dict = defaultdict(int) self.dict = defaultdict(int)
self.dict.setdefault(0) self.dict.setdefault(0)
def load(self, module_dir): def load(self):
"""load module config from module dir """load module config from module dir
""" """
#TODO(ZeyuChen): check module_desc.pb exsitance #TODO(ZeyuChen): check module_desc.pb exsitance
with open(pb_file_path, "rb") as fi: pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "rb") as fi:
self.desc.ParseFromString(fi.read()) self.desc.ParseFromString(fi.read())
if self.desc.contain_assets: if self.desc.contain_assets:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册