提交 5373f2d9 编写于 作者: Z Zeyu Chen

fix test_export_n_load_error

上级 3ab7da0c
...@@ -5,5 +5,5 @@ from __future__ import print_function ...@@ -5,5 +5,5 @@ from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle_hub.module import Module from paddle_hub.module import Module
from paddle_hub.module import ModuleDesc from paddle_hub.module import ModuleConfig
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
...@@ -42,11 +42,12 @@ class Module(object): ...@@ -42,11 +42,12 @@ class Module(object):
# download module # download module
if module_url.startswith("http"): if module_url.startswith("http"):
# if it's remote url link, then download and uncompress it # if it's remote url link, then download and uncompress it
module_name, module_dir = download_and_uncompress(module_url) self.module_name, self.module_dir = download_and_uncompress(
module_url)
else: else:
# otherwise it's local path, no need to deal with it # otherwise it's local path, no need to deal with it
module_dir = module_url self.module_dir = module_url
module_name = module_url.split()[-1] self.module_name = module_url.split()[-1]
# load paddle inference model # load paddle inference model
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -62,6 +63,8 @@ class Module(object): ...@@ -62,6 +63,8 @@ class Module(object):
print("fetch_targets") print("fetch_targets")
print(self.fetch_targets) print(self.fetch_targets)
config = ModuleConfig()
config.load(self.module_dir)
# load assets # load assets
# self.dict = defaultdict(int) # self.dict = defaultdict(int)
# self.dict.setdefault(0) # self.dict.setdefault(0)
...@@ -144,20 +147,21 @@ class Module(object): ...@@ -144,20 +147,21 @@ class Module(object):
return lod_tensor return lod_tensor
def _word_id_mapping(self, inputs): def _word_id_mapping(self, inputs):
return list(map(lambda x: self.dict[x], inputs)) word_dict = self.config.get_dict()
return list(map(lambda x: word_dict[x], inputs))
# load assets folder
def _load_assets(self, module_dir): # # load assets folder
assets_dir = os.path.join(module_dir, ASSETS_NAME) # def _load_assets(self, module_dir):
dict_path = os.path.join(assets_dir, DICT_NAME) # assets_dir = os.path.join(module_dir, ASSETS_NAME)
word_id = 0 # dict_path = os.path.join(assets_dir, DICT_NAME)
# word_id = 0
with open(dict_path) as fi:
words = fi.readlines() # with open(dict_path) as fi:
#TODO(ZeyuChen) check whether word id is duplicated and valid # words = fi.readlines()
for line in fi: # #TODO(ZeyuChen) check whether word id is duplicated and valid
w, w_id = line.split() # for line in fi:
self.dict[w] = int(w_id) # w, w_id = line.split()
# self.dict[w] = int(w_id)
def add_module_feed_list(self, feed_list): def add_module_feed_list(self, feed_list):
self.feed_list = feed_list self.feed_list = feed_list
...@@ -167,10 +171,12 @@ class Module(object): ...@@ -167,10 +171,12 @@ class Module(object):
class ModuleConfig(object): class ModuleConfig(object):
def __init__(self, module_dir): def __init__(self, module_dir, module_name=None):
# generate model desc protobuf # generate model desc protobuf
self.module_dir = module_dir self.module_dir = module_dir
self.desc = module_desc_pb3.ModuleDesc() self.desc = module_desc_pb2.ModuleDesc()
if module_name == None:
module_name = module_dir.split("/")[-1]
self.desc.name = module_name self.desc.name = module_name
print("desc.name=", self.desc.name) print("desc.name=", self.desc.name)
self.desc.signature = "default" self.desc.signature = "default"
...@@ -178,7 +184,11 @@ class ModuleConfig(object): ...@@ -178,7 +184,11 @@ class ModuleConfig(object):
self.desc.contain_assets = True self.desc.contain_assets = True
print("desc.signature=", self.desc.contain_assets) print("desc.signature=", self.desc.contain_assets)
def load(module_dir): # init dict
self.dict = defaultdict(int)
self.dict.setdefault(0)
def load(self, module_dir):
"""load module config from module dir """load module config from module dir
""" """
#TODO(ZeyuChen): check module_desc.pb existence #TODO(ZeyuChen): check module_desc.pb existence
...@@ -187,9 +197,7 @@ class ModuleConfig(object): ...@@ -187,9 +197,7 @@ class ModuleConfig(object):
if self.desc.contain_assets: if self.desc.contain_assets:
# load assets # load assets
self.dict = defaultdict(int) assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
self.dict.setdefault(0)
assets_dir = os.path.join(self.module_dir, assets_dir)
dict_path = os.path.join(assets_dir, DICT_NAME) dict_path = os.path.join(assets_dir, DICT_NAME)
word_id = 0 word_id = 0
...@@ -200,28 +208,31 @@ class ModuleConfig(object): ...@@ -200,28 +208,31 @@ class ModuleConfig(object):
w, w_id = line.split() w, w_id = line.split()
self.dict[w] = int(w_id) self.dict[w] = int(w_id)
def dump(): def dump(self):
# save module_desc.proto first # save module_desc.proto first
pb_path = os.path.join(self.module, "module_desc.pb") pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "wb") as fo: with open(pb_path, "wb") as fo:
fo.write(self.desc.SerializeToString()) fo.write(self.desc.SerializeToString())
# save assets/dictionary # save assets/dictionary
assets_dir = os.path.join(self.module_dir, assets_dir) assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
mkdir(assets_dir) mkdir(assets_dir)
with open(os.path.join(assets_dir, DICT_NAME), "w") as fo: with open(os.path.join(assets_dir, DICT_NAME), "w") as fo:
for w in word_dict: for w in self.dict:
w_id = word_dict[w] w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id)) fo.write("{}\t{}\n".format(w, w_id))
def save_dict(word_dict, dict_name=DICT_NAME): def save_dict(self, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
mkdir(path) mkdir(self.module_dir)
with open(os.path.join(self.module_dir, DICT_NAME), "w") as fo: with open(os.path.join(self.module_dir, DICT_NAME), "w") as fo:
for w in word_dict: for w in word_dict:
self.dict[w] = word_dict[w] self.dict[w] = word_dict[w]
def get_dict(self):
return self.dict
class ModuleUtils(object): class ModuleUtils(object):
def __init__(self): def __init__(self):
......
...@@ -24,6 +24,8 @@ import paddle_hub as hub ...@@ -24,6 +24,8 @@ import paddle_hub as hub
import unittest import unittest
import os import os
from collections import defaultdict
EMBED_SIZE = 16 EMBED_SIZE = 16
HIDDEN_SIZE = 256 HIDDEN_SIZE = 256
N = 5 N = 5
...@@ -42,8 +44,8 @@ def mock_data(): ...@@ -42,8 +44,8 @@ def mock_data():
yield d yield d
#batch_reader = paddle.batch(mock_data, BATCH_SIZE) batch_reader = paddle.batch(mock_data, BATCH_SIZE)
batch_reader = paddle.batch(data, BATCH_SIZE) #batch_reader = paddle.batch(data, BATCH_SIZE)
batch_size = 0 batch_size = 0
for d in batch_reader(): for d in batch_reader():
batch_size += 1 batch_size += 1
...@@ -158,7 +160,7 @@ def train(use_cuda=False): ...@@ -158,7 +160,7 @@ def train(use_cuda=False):
if step % 100 == 0: if step % 100 == 0:
print("Epoch={} Step={} Cost={}".format(epoch, step, cost[0])) print("Epoch={} Step={} Cost={}".format(epoch, step, cost[0]))
model_dir = "./w2v_model" model_dir = "./tmp/w2v_model"
# save part of model # save part of model
var_list_to_saved = [main_program.global_block().var("embedding")] var_list_to_saved = [main_program.global_block().var("embedding")]
print("saving model to %s" % model_dir) print("saving model to %s" % model_dir)
...@@ -169,23 +171,26 @@ def train(use_cuda=False): ...@@ -169,23 +171,26 @@ def train(use_cuda=False):
fluid.io.save_persistables( fluid.io.save_persistables(
executor=exe, dirname=model_dir + "_save_persistables") executor=exe, dirname=model_dir + "_save_persistables")
saved_model_path = "w2v_saved_inference_model" saved_model_dir = "./tmp/w2v_saved_inference_model"
# save inference model including feed and fetch variable info # save inference model including feed and fetch variable info
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname=saved_model_path, dirname=saved_model_dir,
feeded_var_names=["firstw", "secondw", "thirdw", "fourthw"], feeded_var_names=["firstw", "secondw", "thirdw", "fourthw"],
target_vars=[predict_word], target_vars=[predict_word],
executor=exe) executor=exe)
dictionary = [] dictionary = defaultdict(int)
w_id = 0
for w in word_dict: for w in word_dict:
if isinstance(w, bytes): if isinstance(w, bytes):
w = w.decode("ascii") w = w.decode("ascii")
dictionary.append(w) dictionary[w] = w_id
w_id += 1
# save word dict to assets folder # save word dict to assets folder
hub.ModuleConfig.save_module_dict( config = hub.ModuleConfig(model_dir)
module_path=saved_model_path, word_dict=dictionary) config.save_dict(word_dict=dictionary)
config.dump()
def test_save_module(use_cuda=False): def test_save_module(use_cuda=False):
...@@ -200,8 +205,8 @@ def test_save_module(use_cuda=False): ...@@ -200,8 +205,8 @@ def test_save_module(use_cuda=False):
words, word_emb = module_fn() words, word_emb = module_fn()
exe.run(startup_program) exe.run(startup_program)
# load inference embedding parameters # load inference embedding parameters
saved_model_path = "./w2v_saved_inference_model" saved_model_dir = "./tmp/w2v_saved_inference_model"
fluid.io.load_inference_model(executor=exe, dirname=saved_model_path) fluid.io.load_inference_model(executor=exe, dirname=saved_model_dir)
feed_var_list = [main_program.global_block().var("words")] feed_var_list = [main_program.global_block().var("words")]
feeder = fluid.DataFeeder(feed_list=feed_var_list, place=place) feeder = fluid.DataFeeder(feed_list=feed_var_list, place=place)
...@@ -214,29 +219,32 @@ def test_save_module(use_cuda=False): ...@@ -214,29 +219,32 @@ def test_save_module(use_cuda=False):
np_result = np.array(results[0]) np_result = np.array(results[0])
print(np_result) print(np_result)
saved_module_dir = "./test/word2vec_inference_module" # save module_dir
saved_module_dir = "./tmp/word2vec_inference_module"
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname=saved_module_dir, dirname=saved_module_dir,
feeded_var_names=["words"], feeded_var_names=["words"],
target_vars=[word_emb], target_vars=[word_emb],
executor=exe) executor=exe)
dictionary = [] dictionary = defaultdict(int)
for w in word_dict: w_id = 0
if isinstance(w, bytes): for w in word_dict:
w = w.decode("ascii") if isinstance(w, bytes):
dictionary.append(w) w = w.decode("ascii")
# save word dict to assets folder dictionary[w] = w_id
config = hub.ModuleConfig(saved_module_dir) w_id += 1
config.save_dict(word_dict=dictionary) # save word dict to assets folder
config = hub.ModuleConfig(saved_module_dir)
config.save_dict(word_dict=dictionary)
config.dump() config.dump()
def test_load_module(use_cuda=False): def test_load_module(use_cuda=False):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
saved_module_dir = "./test/word2vec_inference_module" saved_module_dir = "./tmp/word2vec_inference_module"
[inference_program, feed_target_names, [inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model( fetch_targets] = fluid.io.load_inference_model(
saved_module_dir, executor=exe) saved_module_dir, executor=exe)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册