提交 6f8fe813 编写于 作者: Z Zeyu Chen

update module, support dict input

上级 0bf02fec
python test_create_hub.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models python test_create_module.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# coding=utf-8
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -82,25 +84,31 @@ class Module(object): ...@@ -82,25 +84,31 @@ class Module(object):
# self.dict.setdefault(0) # self.dict.setdefault(0)
# self._load_assets(module_dir) # self._load_assets(module_dir)
#TODO(ZeyuChen): Need add register more signature to execute different def _construct_feed_dict(self, inputs):
# implmentation """ Construct feed dict according to user's inputs and module config.
def __call__(self, inputs=None, signature=None):
""" Call default signature and return results
""" """
# TODO(ZeyuChen): add proto spec to check which task we need to run feed_dict = {}
# if it's NLP word embedding task, then do words preprocessing for k in inputs:
# if it's image classification or image feature task do the other works if k in self.feed_target_names:
feed_dict[k] = inputs[k]
# if it's return feed_dict
word_ids_lod_tensor = self._process_input(inputs)
np_words_id = np.array(word_ids_lod_tensor) def __call__(self, inputs=None, sign_name="default"):
print("word_ids_lod_tensor\n", np_words_id) """ Call default signature and return results
"""
# word_ids_lod_tensor = self._preprocess_input(inputs)
feed_dict = self._construct_feed_dict(inputs)
print("feed_dict", feed_dict)
ret_numpy = self.config.return_numpy()
print("ret_numpy", ret_numpy)
results = self.exe.run( results = self.exe.run(
self.inference_program, self.inference_program,
feed={self.feed_target_names[0]: word_ids_lod_tensor}, #feed={self.feed_target_names[0]: word_ids_lod_tensor},
feed=feed_dict,
fetch_list=self.fetch_targets, fetch_list=self.fetch_targets,
return_numpy=False) # return_numpy=Flase is important return_numpy=ret_numpy)
print("module fetch_target_names", self.feed_target_names) print("module fetch_target_names", self.feed_target_names)
print("module fetch_targets", self.fetch_targets) print("module fetch_targets", self.fetch_targets)
...@@ -109,9 +117,15 @@ class Module(object): ...@@ -109,9 +117,15 @@ class Module(object):
return np_result return np_result
def get_vars(self): def get_vars(self):
"""
Return variable list of the module program
"""
return self.inference_program.list_vars() return self.inference_program.list_vars()
def get_feed_var(self, key, signature="default"): def get_feed_var(self, key, signature="default"):
"""
Get feed variable according to variable key and signature
"""
for var in self.inference_program.list_vars(): for var in self.inference_program.list_vars():
if var.name == self.config.feed_var_name(key, signature): if var.name == self.config.feed_var_name(key, signature):
return var return var
...@@ -119,6 +133,9 @@ class Module(object): ...@@ -119,6 +133,9 @@ class Module(object):
raise Exception("Can't find input var {}".format(key)) raise Exception("Can't find input var {}".format(key))
def get_fetch_var(self, key, signature="default"): def get_fetch_var(self, key, signature="default"):
"""
Get fetch variable according to variable key and signature
"""
for var in self.inference_program.list_vars(): for var in self.inference_program.list_vars():
if var.name == self.config.fetch_var_name(key, signature): if var.name == self.config.fetch_var_name(key, signature):
return var return var
...@@ -129,7 +146,7 @@ class Module(object): ...@@ -129,7 +146,7 @@ class Module(object):
return self.inference_program return self.inference_program
# for text sequence input, transform to lod tensor as paddle graph's input # for text sequence input, transform to lod tensor as paddle graph's input
def _process_input(self, inputs): def _preprocess_input(self, inputs):
# words id mapping and dealing with oov # words id mapping and dealing with oov
# transform to lod tensor # transform to lod tensor
seq = [] seq = []
...@@ -167,17 +184,22 @@ class ModuleConfig(object): ...@@ -167,17 +184,22 @@ class ModuleConfig(object):
self.desc = module_desc_pb2.ModuleDesc() self.desc = module_desc_pb2.ModuleDesc()
if module_name == None: if module_name == None:
module_name = module_dir.split("/")[-1] module_name = module_dir.split("/")[-1]
# initialize module config default value
self.desc.name = module_name self.desc.name = module_name
print("desc.name=", self.desc.name)
self.desc.contain_assets = True self.desc.contain_assets = True
print("desc.signature=", self.desc.contain_assets) self.desc.return_numpy = False
# init dict # init dict
self.dict = defaultdict(int) self.dict = defaultdict(int)
self.dict.setdefault(0) self.dict.setdefault(0)
def get_dict(self):
""" Return dictionary in Module"""
return self.dict
def load(self): def load(self):
"""load module config from module dir """
Load module config from module directory.
""" """
#TODO(ZeyuChen): check module_desc.pb exsitance #TODO(ZeyuChen): check module_desc.pb exsitance
pb_path = os.path.join(self.module_dir, "module_desc.pb") pb_path = os.path.join(self.module_dir, "module_desc.pb")
...@@ -198,8 +220,7 @@ class ModuleConfig(object): ...@@ -198,8 +220,7 @@ class ModuleConfig(object):
self.dict[w] = int(w_id) self.dict[w] = int(w_id)
def dump(self): def dump(self):
""" """ Save Module configure file to disk.
save module_desc.proto first
""" """
pb_path = os.path.join(self.module_dir, "module_desc.pb") pb_path = os.path.join(self.module_dir, "module_desc.pb")
with open(pb_path, "wb") as fo: with open(pb_path, "wb") as fo:
...@@ -213,6 +234,11 @@ class ModuleConfig(object): ...@@ -213,6 +234,11 @@ class ModuleConfig(object):
w_id = self.dict[w] w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id)) fo.write("{}\t{}\n".format(w, w_id))
def return_numpy(self):
"""Return numpy or not according to the proto config.
"""
return self.desc.return_numpy
def save_dict(self, word_dict, dict_name=DICT_NAME): def save_dict(self, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
...@@ -223,10 +249,13 @@ class ModuleConfig(object): ...@@ -223,10 +249,13 @@ class ModuleConfig(object):
# for w in word_dict: # for w in word_dict:
# self.dict[w] = word_dict[w] # self.dict[w] = word_dict[w]
def get_dict(self):
return self.dict
def register_feed_signature(self, feed_desc, sign_name="default"): def register_feed_signature(self, feed_desc, sign_name="default"):
""" Register feed signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for k in feed_desc: for k in feed_desc:
feed = self.desc.sign2var[sign_name].feed_desc.add() feed = self.desc.sign2var[sign_name].feed_desc.add()
...@@ -234,6 +263,12 @@ class ModuleConfig(object): ...@@ -234,6 +263,12 @@ class ModuleConfig(object):
feed.var_name = feed_desc[k] feed.var_name = feed_desc[k]
def register_fetch_signature(self, fetch_desc, sign_name="default"): def register_fetch_signature(self, fetch_desc, sign_name="default"):
""" Register fetch signature to the Module
Args:
fetch_desc: a dictionary of signature to input variable
sign_name: signature name, use "default" as default signature
"""
#TODO(ZeyuChen) check fetch_desc key is valid and no duplicated #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
for k in fetch_desc: for k in fetch_desc:
fetch = self.desc.sign2var[sign_name].fetch_desc.add() fetch = self.desc.sign2var[sign_name].fetch_desc.add()
...@@ -241,12 +276,16 @@ class ModuleConfig(object): ...@@ -241,12 +276,16 @@ class ModuleConfig(object):
fetch.var_name = fetch_desc[k] fetch.var_name = fetch_desc[k]
def feed_var_name(self, key, sign_name="default"): def feed_var_name(self, key, sign_name="default"):
"""get module's feed/input variable name
"""
for desc in self.desc.sign2var[sign_name].feed_desc: for desc in self.desc.sign2var[sign_name].feed_desc:
if desc.key == key: if desc.key == key:
return desc.var_name return desc.var_name
raise Exception("feed variable {} not found".format(key)) raise Exception("feed variable {} not found".format(key))
def fetch_var_name(self, key, sign_name="default"): def fetch_var_name(self, key, sign_name="default"):
"""get module's fetch/output variable name
"""
for desc in self.desc.sign2var[sign_name].fetch_desc: for desc in self.desc.sign2var[sign_name].fetch_desc:
if desc.key == key: if desc.key == key:
return desc.var_name return desc.var_name
...@@ -278,13 +317,3 @@ class ModuleUtils(object): ...@@ -278,13 +317,3 @@ class ModuleUtils(object):
# print("********************************") # print("********************************")
# print(program) # print(program)
# print("********************************") # print("********************************")
if __name__ == "__main__":
url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
m = Module(module_url=url)
inputs = [["it", "is", "new"], ["hello", "world"]]
#tensor = m._process_input(inputs)
#print(tensor)
result = m(inputs)
print(result)
...@@ -17,14 +17,25 @@ import paddle_hub as hub ...@@ -17,14 +17,25 @@ import paddle_hub as hub
class TestModule(unittest.TestCase): class TestModule(unittest.TestCase):
#TODO(ZeyuChen): add setup for test envrinoment prepration
def test_word2vec_module_usage(self): def test_word2vec_module_usage(self):
url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz" pass
module = Module(module_url=url) # url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
inputs = [["it", "is", "new"], ["hello", "world"]] # module = Module(module_url=url)
tensor = module._process_input(inputs) # inputs = [["it", "is", "new"], ["hello", "world"]]
print(tensor) # tensor = module._process_input(inputs)
result = module(inputs) # print(tensor)
print(result) # result = module(inputs)
# print(result)
def test_senta_module_usage(self):
pass
# m = Module(module_dir="./models/bow_net")
# inputs = [["外人", "爸妈", "翻车"], ["金钱", "电量"]]
# tensor = m._preprocess_input(inputs)
# print(tensor)
# result = m({"words": tensor})
# print(result)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册