提交 6f8fe813 编写于 作者: Z Zeyu Chen

update module, support dict input

上级 0bf02fec
python test_create_hub.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models
python test_create_module.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -82,25 +84,31 @@ class Module(object):
# self.dict.setdefault(0)
# self._load_assets(module_dir)
#TODO(ZeyuChen): Need add register more signature to execute different
# implmentation
def __call__(self, inputs=None, signature=None):
""" Call default signature and return results
def _construct_feed_dict(self, inputs):
""" Construct feed dict according to user's inputs and module config.
"""
# TODO(ZeyuChen): add proto spec to check which task we need to run
# if it's NLP word embedding task, then do words preprocessing
# if it's image classification or image feature task do the other works
feed_dict = {}
for k in inputs:
if k in self.feed_target_names:
feed_dict[k] = inputs[k]
# if it's
word_ids_lod_tensor = self._process_input(inputs)
np_words_id = np.array(word_ids_lod_tensor)
print("word_ids_lod_tensor\n", np_words_id)
return feed_dict
def __call__(self, inputs=None, sign_name="default"):
""" Call default signature and return results
"""
# word_ids_lod_tensor = self._preprocess_input(inputs)
feed_dict = self._construct_feed_dict(inputs)
print("feed_dict", feed_dict)
ret_numpy = self.config.return_numpy()
print("ret_numpy", ret_numpy)
results = self.exe.run(
self.inference_program,
feed={self.feed_target_names[0]: word_ids_lod_tensor},
#feed={self.feed_target_names[0]: word_ids_lod_tensor},
feed=feed_dict,
fetch_list=self.fetch_targets,
return_numpy=False) # return_numpy=Flase is important
return_numpy=ret_numpy)
print("module fetch_target_names", self.feed_target_names)
print("module fetch_targets", self.fetch_targets)
......@@ -109,9 +117,15 @@ class Module(object):
return np_result
def get_vars(self):
"""
Return variable list of the module program
"""
return self.inference_program.list_vars()
def get_feed_var(self, key, signature="default"):
"""
Get feed variable according to variable key and signature
"""
for var in self.inference_program.list_vars():
if var.name == self.config.feed_var_name(key, signature):
return var
......@@ -119,6 +133,9 @@ class Module(object):
raise Exception("Can't find input var {}".format(key))
def get_fetch_var(self, key, signature="default"):
"""
Get fetch variable according to variable key and signature
"""
for var in self.inference_program.list_vars():
if var.name == self.config.fetch_var_name(key, signature):
return var
......@@ -129,7 +146,7 @@ class Module(object):
return self.inference_program
# for text sequence input, transform to lod tensor as paddle graph's input
def _process_input(self, inputs):
def _preprocess_input(self, inputs):
# words id mapping and dealing with oov
# transform to lod tensor
seq = []
......@@ -167,17 +184,22 @@ class ModuleConfig(object):
self.desc = module_desc_pb2.ModuleDesc()
if module_name == None:
module_name = module_dir.split("/")[-1]
# initialize module config default value
self.desc.name = module_name
print("desc.name=", self.desc.name)
self.desc.contain_assets = True
print("desc.signature=", self.desc.contain_assets)
self.desc.return_numpy = False
# init dict
self.dict = defaultdict(int)
self.dict.setdefault(0)
def get_dict(self):
""" Return dictionary in Module"""
return self.dict
def load(self):
"""load module config from module dir
"""
Load module config from module directory.
"""
#TODO(ZeyuChen): check module_desc.pb exsitance
pb_path = os.path.join(self.module_dir, "module_desc.pb")
......@@ -198,8 +220,7 @@ class ModuleConfig(object):
self.dict[w] = int(w_id)
def dump(self):
"""
save module_desc.proto first
""" Save Module configure file to disk.
"""
pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "wb") as fo:
......@@ -213,6 +234,11 @@ class ModuleConfig(object):
w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id))
def return_numpy(self):
"""Return numpy or not according to the proto config.
"""
return self.desc.return_numpy
def save_dict(self, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module
"""
......@@ -223,10 +249,13 @@ class ModuleConfig(object):
# for w in word_dict:
# self.dict[w] = word_dict[w]
def get_dict(self):
return self.dict
def register_feed_signature(self, feed_desc, sign_name="default"):
""" Register feed signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for k in feed_desc:
feed = self.desc.sign2var[sign_name].feed_desc.add()
......@@ -234,6 +263,12 @@ class ModuleConfig(object):
feed.var_name = feed_desc[k]
def register_fetch_signature(self, fetch_desc, sign_name="default"):
""" Register fetch signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for k in fetch_desc:
fetch = self.desc.sign2var[sign_name].fetch_desc.add()
......@@ -241,12 +276,16 @@ class ModuleConfig(object):
fetch.var_name = fetch_desc[k]
def feed_var_name(self, key, sign_name="default"):
"""get module's feed/input variable name
"""
for desc in self.desc.sign2var[sign_name].feed_desc:
if desc.key == key:
return desc.var_name
raise Exception("feed variable {} not found".format(key))
def fetch_var_name(self, key, sign_name="default"):
"""get module's fetch/output variable name
"""
for desc in self.desc.sign2var[sign_name].fetch_desc:
if desc.key == key:
return desc.var_name
......@@ -278,13 +317,3 @@ class ModuleUtils(object):
# print("********************************")
# print(program)
# print("********************************")
if __name__ == "__main__":
url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
m = Module(module_url=url)
inputs = [["it", "is", "new"], ["hello", "world"]]
#tensor = m._process_input(inputs)
#print(tensor)
result = m(inputs)
print(result)
......@@ -17,14 +17,25 @@ import paddle_hub as hub
class TestModule(unittest.TestCase):
#TODO(ZeyuChen): add setup for test envrinoment prepration
def test_word2vec_module_usage(self):
url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
module = Module(module_url=url)
inputs = [["it", "is", "new"], ["hello", "world"]]
tensor = module._process_input(inputs)
print(tensor)
result = module(inputs)
print(result)
pass
# url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
# module = Module(module_url=url)
# inputs = [["it", "is", "new"], ["hello", "world"]]
# tensor = module._process_input(inputs)
# print(tensor)
# result = module(inputs)
# print(result)
def test_senta_module_usage(self):
pass
# m = Module(module_dir="./models/bow_net")
# inputs = [["外人", "爸妈", "翻车"], ["金钱", "电量"]]
# tensor = m._preprocess_input(inputs)
# print(tensor)
# result = m({"words": tensor})
# print(result)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册