From 2f2fe7afecd5695e20968f7867d6d5cb5c7b4ee9 Mon Sep 17 00:00:00 2001
From: tianxin
Date: Wed, 17 Jul 2019 10:27:42 +0800
Subject: [PATCH] run classification inference via infer_program, fix #207

---
 ERNIE/README.md              |  2 +-
 ERNIE/finetune/classifier.py |  2 +-
 ERNIE/predict_classifier.py  | 59 ++++++++++++++++++++-----------
 ERNIE/reader/task_reader.py  | 68 ++++++++++++++++++++++--------------
 4 files changed, 83 insertions(+), 48 deletions(-)

diff --git a/ERNIE/README.md b/ERNIE/README.md
index 81415c3..b47e856 100644
--- a/ERNIE/README.md
+++ b/ERNIE/README.md
@@ -316,4 +316,4 @@ python -u predict_classifier.py \
 In actual use, specify the model to use for prediction via `init_checkpoint`, the data file to predict on via `predict_set`, and the number of classification labels via `num_labels`.
 
-**Note**: the format of predict_set is identical to that of dev_set and test_set: a 2-column/3-column tsv file consisting of text_a, text_b (optional) and label, where the label column in predict_set is only a placeholder and can be set to all 0s.
+**Note**: predict_set is a 1-column/2-column tsv file consisting of text_a and text_b (optional).
diff --git a/ERNIE/finetune/classifier.py b/ERNIE/finetune/classifier.py
index e7a79ac..c69deaf 100644
--- a/ERNIE/finetune/classifier.py
+++ b/ERNIE/finetune/classifier.py
@@ -65,7 +65,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
     if is_prediction:
         probs = fluid.layers.softmax(logits)
         feed_targets_name = [
-            src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
+            src_ids.name, sent_ids.name, pos_ids.name, input_mask.name
         ]
         return pyreader, probs, feed_targets_name
diff --git a/ERNIE/predict_classifier.py b/ERNIE/predict_classifier.py
index 924d4fa..ecda8a3 100644
--- a/ERNIE/predict_classifier.py
+++ b/ERNIE/predict_classifier.py
@@ -37,6 +37,7 @@ parser = argparse.ArgumentParser(__doc__)
 model_g = ArgumentGroup(parser, "model", "options to init, resume and save model.")
 model_g.add_arg("ernie_config_path", str, None, "Path to the json file for bert model config.")
 model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
+model_g.add_arg("save_inference_model_path", str, "inference_model", "If set, save the inference model to this path.")
 model_g.add_arg("use_fp16", bool, False, "Whether to resume parameters from fp16 checkpoint.")
 model_g.add_arg("num_labels", int, 2, "num labels for classify")
 
@@ -65,7 +66,8 @@ def main(args):
         label_map_config=args.label_map_config,
         max_seq_len=args.max_seq_len,
         do_lower_case=args.do_lower_case,
-        in_tokens=False)
+        in_tokens=False,
+        is_inference=True)
 
     predict_prog = fluid.Program()
     predict_startup = fluid.Program()
@@ -95,7 +97,25 @@ def main(args):
     else:
         raise ValueError("args 'init_checkpoint' should be set for prediction!")
 
-    predict_exe = fluid.Executor(place)
+    assert args.save_inference_model_path, "args save_inference_model_path should be set for prediction"
+    _, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/'))
+    dir_name = ckpt_dir + '_inference_model'
+    model_path = os.path.join(args.save_inference_model_path, dir_name)
+    print("save inference model to %s" % model_path)
+    fluid.io.save_inference_model(
+        model_path,
+        feed_target_names, [probs],
+        exe,
+        main_program=predict_prog)
+
+    print("load inference model from %s" % model_path)
+    infer_program, feed_target_names, probs = fluid.io.load_inference_model(
+        model_path, exe)
+
+    src_ids = feed_target_names[0]
+    sent_ids = feed_target_names[1]
+    pos_ids = feed_target_names[2]
+    input_mask = feed_target_names[3]
 
     predict_data_generator = reader.data_generator(
         input_file=args.predict_set,
@@ -103,25 +123,24 @@ def main(args):
         epoch=1,
         shuffle=False)
 
-    predict_pyreader.decorate_tensor_provider(predict_data_generator)
- - predict_pyreader.start() - all_results = [] - time_begin = time.time() - while True: - try: - results = predict_exe.run(program=predict_prog, fetch_list=[probs.name]) - all_results.extend(results[0]) - except fluid.core.EOFException: - predict_pyreader.reset() - break - time_end = time.time() - - np.set_printoptions(precision=4, suppress=True) print("-------------- prediction results --------------") - for index, result in enumerate(all_results): - print(str(index) + '\t{}'.format(result)) - + np.set_printoptions(precision=4, suppress=True) + index = 0 + for sample in predict_data_generator(): + src_ids_data = sample[0] + sent_ids_data = sample[1] + pos_ids_data = sample[2] + input_mask_data = sample[3] + output = exe.run( + infer_program, + feed={src_ids: src_ids_data, + sent_ids: sent_ids_data, + pos_ids: pos_ids_data, + input_mask: input_mask_data}, + fetch_list=probs) + for single_result in output[0]: + print("example_index:{}\t{}".format(index, single_result)) + index += 1 if __name__ == '__main__': print_arguments(args) diff --git a/ERNIE/reader/task_reader.py b/ERNIE/reader/task_reader.py index 5dab909..c58c9a3 100644 --- a/ERNIE/reader/task_reader.py +++ b/ERNIE/reader/task_reader.py @@ -28,6 +28,7 @@ class BaseReader(object): max_seq_len=512, do_lower_case=True, in_tokens=False, + is_inference=False, random_seed=None): self.max_seq_len = max_seq_len self.tokenizer = tokenization.FullTokenizer( @@ -37,6 +38,7 @@ class BaseReader(object): self.cls_id = self.vocab["[CLS]"] self.sep_id = self.vocab["[SEP]"] self.in_tokens = in_tokens + self.is_inference = is_inference np.random.seed(random_seed) @@ -141,25 +143,33 @@ class BaseReader(object): token_ids = tokenizer.convert_tokens_to_ids(tokens) position_ids = list(range(len(token_ids))) - if self.label_map: - label_id = self.label_map[example.label] + if self.is_inference: + Record = namedtuple('Record', + ['token_ids', 'text_type_ids', 'position_ids']) + record = Record( + token_ids=token_ids, + text_type_ids=text_type_ids, + position_ids=position_ids) else: - label_id = example.label - - Record = namedtuple( - 'Record', - ['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid']) - - qid = None - if "qid" in example._fields: - qid = example.qid - - record = Record( - token_ids=token_ids, - text_type_ids=text_type_ids, - position_ids=position_ids, - label_id=label_id, - qid=qid) + if self.label_map: + label_id = self.label_map[example.label] + else: + label_id = example.label + + Record = namedtuple('Record', [ + 'token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid' + ]) + + qid = None + if "qid" in example._fields: + qid = example.qid + + record = Record( + token_ids=token_ids, + text_type_ids=text_type_ids, + position_ids=position_ids, + label_id=label_id, + qid=qid) return record def _prepare_batch_data(self, examples, batch_size, phase=None): @@ -235,14 +245,18 @@ class ClassifyReader(BaseReader): batch_token_ids = [record.token_ids for record in batch_records] batch_text_type_ids = [record.text_type_ids for record in batch_records] batch_position_ids = [record.position_ids for record in batch_records] - batch_labels = [record.label_id for record in batch_records] - batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1]) - if batch_records[0].qid is not None: - batch_qids = [record.qid for record in batch_records] - batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) - else: - batch_qids = np.array([]).astype("int64").reshape([-1, 1]) + if not self.is_inference: + batch_labels = 
[record.label_id for record in batch_records] + batch_labels = np.array(batch_labels).astype("int64").reshape( + [-1, 1]) + + if batch_records[0].qid is not None: + batch_qids = [record.qid for record in batch_records] + batch_qids = np.array(batch_qids).astype("int64").reshape( + [-1, 1]) + else: + batch_qids = np.array([]).astype("int64").reshape([-1, 1]) # padding padded_token_ids, input_mask = pad_batch_data( @@ -254,8 +268,10 @@ class ClassifyReader(BaseReader): return_list = [ padded_token_ids, padded_text_type_ids, padded_position_ids, - input_mask, batch_labels, batch_qids + input_mask ] + if not self.is_inference: + return_list += [batch_labels, batch_qids] return return_list -- GitLab
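
Note (not part of the patch): with the updated README, `predict_set` only needs the text columns. Below is a minimal sketch of generating such a file, assuming the reader still takes its column names from the tsv header row; the file name and example sentences are made up.

```python
# -*- coding: utf-8 -*-
# Hypothetical helper: write a two-column predict_set tsv (text_a, text_b).
# BaseReader._read_tsv reads the first row as the column names, and with
# is_inference=True no label column is required.
import io

rows = [
    (u"这家酒店的位置很方便", u"服务态度也不错"),
    (u"电池续航时间太短", u"不太推荐购买"),
]
# "predict.tsv" is a placeholder path; pass it to predict_classifier.py
# through --predict_set.
with io.open("predict.tsv", "w", encoding="utf-8") as f:
    f.write(u"text_a\ttext_b\n")
    for text_a, text_b in rows:
        f.write(u"{}\t{}\n".format(text_a, text_b))
```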
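For anyone trying the change out, here is a minimal standalone sketch (also not part of the patch) of consuming the exported directory. It assumes PaddlePaddle 1.x `fluid` APIs, a model directory produced by the updated `predict_classifier.py` (so the saved feed order is src_ids, sent_ids, pos_ids, input_mask), and placeholder values for the path and the input tensors.

```python
# Minimal sketch: load the exported inference model and run one fake batch.
# All input shapes follow ClassifyReader's padded output: [batch_size, max_seq_len, 1].
import numpy as np
import paddle.fluid as fluid

model_path = "inference_model/step_10000_inference_model"  # placeholder path

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# load_inference_model returns the pruned program, the feed variable names in
# the order they were saved (src_ids, sent_ids, pos_ids, input_mask), and the
# fetch targets (here: probs).
infer_program, feed_names, fetch_targets = fluid.io.load_inference_model(
    model_path, exe)

batch_size, seq_len = 1, 8
src_ids = np.zeros((batch_size, seq_len, 1), dtype="int64")      # dummy token ids
sent_ids = np.zeros((batch_size, seq_len, 1), dtype="int64")     # text type ids
pos_ids = np.arange(seq_len, dtype="int64").reshape(batch_size, seq_len, 1)
input_mask = np.ones((batch_size, seq_len, 1), dtype="float32")  # no padding

output = exe.run(
    infer_program,
    feed=dict(zip(feed_names, [src_ids, sent_ids, pos_ids, input_mask])),
    fetch_list=fetch_targets)
print(output[0])  # softmax probabilities, shape [batch_size, num_labels]
```

These four arrays match what `ClassifyReader.data_generator` yields per batch when `is_inference=True`, which is why the patched `predict_classifier.py` can feed `sample[0]`..`sample[3]` to the loaded program directly.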