From 6f8fe813fdda1b922d47465299b549ed8b718b65 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Mon, 7 Jan 2019 02:15:14 +0800 Subject: [PATCH] update module, support dict input --- Senta/create_module.sh | 2 +- paddle_hub/module.py | 93 +++++++++++++++++++++++++++--------------- tests/test_module.py | 25 ++++++++---- 3 files changed, 80 insertions(+), 40 deletions(-) diff --git a/Senta/create_module.sh b/Senta/create_module.sh index 6996b565..ed33c256 100755 --- a/Senta/create_module.sh +++ b/Senta/create_module.sh @@ -1 +1 @@ -python test_create_hub.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models +python test_create_module.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 48af741a..cc4bbab4 100755 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# coding=utf-8 + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -82,25 +84,31 @@ class Module(object): # self.dict.setdefault(0) # self._load_assets(module_dir) - #TODO(ZeyuChen): Need add register more signature to execute different - # implmentation - def __call__(self, inputs=None, signature=None): - """ Call default signature and return results + def _construct_feed_dict(self, inputs): + """ Construct feed dict according to user's inputs and module config. """ - # TODO(ZeyuChen): add proto spec to check which task we need to run - # if it's NLP word embedding task, then do words preprocessing - # if it's image classification or image feature task do the other works + feed_dict = {} + for k in inputs: + if k in self.feed_target_names: + feed_dict[k] = inputs[k] - # if it's - word_ids_lod_tensor = self._process_input(inputs) - np_words_id = np.array(word_ids_lod_tensor) - print("word_ids_lod_tensor\n", np_words_id) + return feed_dict + + def __call__(self, inputs=None, sign_name="default"): + """ Call default signature and return results + """ + # word_ids_lod_tensor = self._preprocess_input(inputs) + feed_dict = self._construct_feed_dict(inputs) + print("feed_dict", feed_dict) + ret_numpy = self.config.return_numpy() + print("ret_numpy", ret_numpy) results = self.exe.run( self.inference_program, - feed={self.feed_target_names[0]: word_ids_lod_tensor}, + #feed={self.feed_target_names[0]: word_ids_lod_tensor}, + feed=feed_dict, fetch_list=self.fetch_targets, - return_numpy=False) # return_numpy=Flase is important + return_numpy=ret_numpy) print("module fetch_target_names", self.feed_target_names) print("module fetch_targets", self.fetch_targets) @@ -109,9 +117,15 @@ class Module(object): return np_result def get_vars(self): + """ + Return variable list of the module program + """ return self.inference_program.list_vars() def get_feed_var(self, key, signature="default"): + """ + Get feed variable according to variable key and signature + """ for var in self.inference_program.list_vars(): if var.name == self.config.feed_var_name(key, signature): return var @@ -119,6 +133,9 @@ class Module(object): raise Exception("Can't find input var {}".format(key)) def get_fetch_var(self, key, signature="default"): + """ + Get fetch variable according to variable key and signature + """ for var in self.inference_program.list_vars(): if var.name == self.config.fetch_var_name(key, signature): return var @@ -129,7 +146,7 @@ class Module(object): return self.inference_program # for text sequence input, transform to lod tensor as paddle graph's input - def _process_input(self, inputs): + def _preprocess_input(self, inputs): # words id mapping and dealing with oov # transform to lod tensor seq = [] @@ -167,17 +184,22 @@ class ModuleConfig(object): self.desc = module_desc_pb2.ModuleDesc() if module_name == None: module_name = module_dir.split("/")[-1] + # initialize module config default value self.desc.name = module_name - print("desc.name=", self.desc.name) self.desc.contain_assets = True - print("desc.signature=", self.desc.contain_assets) + self.desc.return_numpy = False # init dict self.dict = defaultdict(int) self.dict.setdefault(0) + def get_dict(self): + """ Return dictionary in Module""" + return self.dict + def load(self): - """load module config from module dir + """ + Load module config from module directory. """ #TODO(ZeyuChen): check module_desc.pb exsitance pb_path = os.path.join(self.module_dir, "module_desc.pb") @@ -198,8 +220,7 @@ class ModuleConfig(object): self.dict[w] = int(w_id) def dump(self): - """ - save module_desc.proto first + """ Save Module configure file to disk. """ pb_path = os.path.join(self.module_dir, "module_desc.pb") with open(pb_path, "wb") as fo: @@ -213,6 +234,11 @@ class ModuleConfig(object): w_id = self.dict[w] fo.write("{}\t{}\n".format(w, w_id)) + def return_numpy(self): + """Return numpy or not according to the proto config. + """ + return self.desc.return_numpy + def save_dict(self, word_dict, dict_name=DICT_NAME): """ Save dictionary for NLP module """ @@ -223,10 +249,13 @@ class ModuleConfig(object): # for w in word_dict: # self.dict[w] = word_dict[w] - def get_dict(self): - return self.dict - def register_feed_signature(self, feed_desc, sign_name="default"): + """ Register feed signature to the Module + + Args: + fetch_desc: a dictionary of signature to input variable + sign_name: signature name, use "default" as default signature + """ #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated for k in feed_desc: feed = self.desc.sign2var[sign_name].feed_desc.add() @@ -234,6 +263,12 @@ class ModuleConfig(object): feed.var_name = feed_desc[k] def register_fetch_signature(self, fetch_desc, sign_name="default"): + """ Register fetch signature to the Module + + Args: + fetch_desc: a dictionary of signature to input variable + sign_name: signature name, use "default" as default signature + """ #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated for k in fetch_desc: fetch = self.desc.sign2var[sign_name].fetch_desc.add() @@ -241,12 +276,16 @@ class ModuleConfig(object): fetch.var_name = fetch_desc[k] def feed_var_name(self, key, sign_name="default"): + """get module's feed/input variable name + """ for desc in self.desc.sign2var[sign_name].feed_desc: if desc.key == key: return desc.var_name raise Exception("feed variable {} not found".format(key)) def fetch_var_name(self, key, sign_name="default"): + """get module's fetch/output variable name + """ for desc in self.desc.sign2var[sign_name].fetch_desc: if desc.key == key: return desc.var_name @@ -278,13 +317,3 @@ class ModuleUtils(object): # print("********************************") # print(program) # print("********************************") - - -if __name__ == "__main__": - url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz" - m = Module(module_url=url) - inputs = [["it", "is", "new"], ["hello", "world"]] - #tensor = m._process_input(inputs) - #print(tensor) - result = m(inputs) - print(result) diff --git a/tests/test_module.py b/tests/test_module.py index a639da50..1754a179 100755 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -17,14 +17,25 @@ import paddle_hub as hub class TestModule(unittest.TestCase): + #TODO(ZeyuChen): add setup for test envrinoment prepration def test_word2vec_module_usage(self): - url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz" - module = Module(module_url=url) - inputs = [["it", "is", "new"], ["hello", "world"]] - tensor = module._process_input(inputs) - print(tensor) - result = module(inputs) - print(result) + pass + # url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz" + # module = Module(module_url=url) + # inputs = [["it", "is", "new"], ["hello", "world"]] + # tensor = module._process_input(inputs) + # print(tensor) + # result = module(inputs) + # print(result) + + def test_senta_module_usage(self): + pass + # m = Module(module_dir="./models/bow_net") + # inputs = [["外人", "爸妈", "翻车"], ["金钱", "电量"]] + # tensor = m._preprocess_input(inputs) + # print(tensor) + # result = m({"words": tensor}) + # print(result) if __name__ == "__main__": -- GitLab