提交 2f2fe7af 编写于 作者: T tianxin

do classify_inference by infer_program

fix #207
上级 39d4571d
...@@ -316,4 +316,4 @@ python -u predict_classifier.py \ ...@@ -316,4 +316,4 @@ python -u predict_classifier.py \
实际使用时,需要通过 `init_checkpoint` 指定预测用的模型,通过 `predict_set` 指定待预测的数据文件,通过 `num_labels` 配置分类的类别数目; 实际使用时,需要通过 `init_checkpoint` 指定预测用的模型,通过 `predict_set` 指定待预测的数据文件,通过 `num_labels` 配置分类的类别数目;
**Note**: predict_set 的数据格式与 dev_set 和 test_set 的数据格式完全一致,是由 text_a、text_b(可选) 、label 组成的2列/3列 tsv 文件,predict_set 中的 label 列起到占位符的作用,全部置 0 即可; **Note**: predict_set 的数据格式是由 text_a、text_b(可选) 组成的1列/2列 tsv 文件;
...@@ -65,7 +65,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False): ...@@ -65,7 +65,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
if is_prediction: if is_prediction:
probs = fluid.layers.softmax(logits) probs = fluid.layers.softmax(logits)
feed_targets_name = [ feed_targets_name = [
src_ids.name, pos_ids.name, sent_ids.name, input_mask.name src_ids.name, sent_ids.name, pos_ids.name, input_mask.name
] ]
return pyreader, probs, feed_targets_name return pyreader, probs, feed_targets_name
......
...@@ -37,6 +37,7 @@ parser = argparse.ArgumentParser(__doc__) ...@@ -37,6 +37,7 @@ parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "options to init, resume and save model.") model_g = ArgumentGroup(parser, "model", "options to init, resume and save model.")
model_g.add_arg("ernie_config_path", str, None, "Path to the json file for bert model config.") model_g.add_arg("ernie_config_path", str, None, "Path to the json file for bert model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.") model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("save_inference_model_path", str, "inference_model", "If set, save the inference model to this path.")
model_g.add_arg("use_fp16", bool, False, "Whether to resume parameters from fp16 checkpoint.") model_g.add_arg("use_fp16", bool, False, "Whether to resume parameters from fp16 checkpoint.")
model_g.add_arg("num_labels", int, 2, "num labels for classify") model_g.add_arg("num_labels", int, 2, "num labels for classify")
...@@ -65,7 +66,8 @@ def main(args): ...@@ -65,7 +66,8 @@ def main(args):
label_map_config=args.label_map_config, label_map_config=args.label_map_config,
max_seq_len=args.max_seq_len, max_seq_len=args.max_seq_len,
do_lower_case=args.do_lower_case, do_lower_case=args.do_lower_case,
in_tokens=False) in_tokens=False,
is_inference=True)
predict_prog = fluid.Program() predict_prog = fluid.Program()
predict_startup = fluid.Program() predict_startup = fluid.Program()
...@@ -95,7 +97,25 @@ def main(args): ...@@ -95,7 +97,25 @@ def main(args):
else: else:
raise ValueError("args 'init_checkpoint' should be set for prediction!") raise ValueError("args 'init_checkpoint' should be set for prediction!")
predict_exe = fluid.Executor(place) assert args.save_inference_model_path, "args save_inference_model_path should be set for prediction"
_, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/'))
dir_name = ckpt_dir + '_inference_model'
model_path = os.path.join(args.save_inference_model_path, dir_name)
print("save inference model to %s" % model_path)
fluid.io.save_inference_model(
model_path,
feed_target_names, [probs],
exe,
main_program=predict_prog)
print("load inference model from %s" % model_path)
infer_program, feed_target_names, probs = fluid.io.load_inference_model(
model_path, exe)
src_ids = feed_target_names[0]
sent_ids = feed_target_names[1]
pos_ids = feed_target_names[2]
input_mask = feed_target_names[3]
predict_data_generator = reader.data_generator( predict_data_generator = reader.data_generator(
input_file=args.predict_set, input_file=args.predict_set,
...@@ -103,25 +123,24 @@ def main(args): ...@@ -103,25 +123,24 @@ def main(args):
epoch=1, epoch=1,
shuffle=False) shuffle=False)
predict_pyreader.decorate_tensor_provider(predict_data_generator)
predict_pyreader.start()
all_results = []
time_begin = time.time()
while True:
try:
results = predict_exe.run(program=predict_prog, fetch_list=[probs.name])
all_results.extend(results[0])
except fluid.core.EOFException:
predict_pyreader.reset()
break
time_end = time.time()
np.set_printoptions(precision=4, suppress=True)
print("-------------- prediction results --------------") print("-------------- prediction results --------------")
for index, result in enumerate(all_results): np.set_printoptions(precision=4, suppress=True)
print(str(index) + '\t{}'.format(result)) index = 0
for sample in predict_data_generator():
src_ids_data = sample[0]
sent_ids_data = sample[1]
pos_ids_data = sample[2]
input_mask_data = sample[3]
output = exe.run(
infer_program,
feed={src_ids: src_ids_data,
sent_ids: sent_ids_data,
pos_ids: pos_ids_data,
input_mask: input_mask_data},
fetch_list=probs)
for single_result in output[0]:
print("example_index:{}\t{}".format(index, single_result))
index += 1
if __name__ == '__main__': if __name__ == '__main__':
print_arguments(args) print_arguments(args)
......
...@@ -28,6 +28,7 @@ class BaseReader(object): ...@@ -28,6 +28,7 @@ class BaseReader(object):
max_seq_len=512, max_seq_len=512,
do_lower_case=True, do_lower_case=True,
in_tokens=False, in_tokens=False,
is_inference=False,
random_seed=None): random_seed=None):
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer( self.tokenizer = tokenization.FullTokenizer(
...@@ -37,6 +38,7 @@ class BaseReader(object): ...@@ -37,6 +38,7 @@ class BaseReader(object):
self.cls_id = self.vocab["[CLS]"] self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"] self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens self.in_tokens = in_tokens
self.is_inference = is_inference
np.random.seed(random_seed) np.random.seed(random_seed)
...@@ -141,14 +143,22 @@ class BaseReader(object): ...@@ -141,14 +143,22 @@ class BaseReader(object):
token_ids = tokenizer.convert_tokens_to_ids(tokens) token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids))) position_ids = list(range(len(token_ids)))
if self.is_inference:
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids)
else:
if self.label_map: if self.label_map:
label_id = self.label_map[example.label] label_id = self.label_map[example.label]
else: else:
label_id = example.label label_id = example.label
Record = namedtuple( Record = namedtuple('Record', [
'Record', 'token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'
['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid']) ])
qid = None qid = None
if "qid" in example._fields: if "qid" in example._fields:
...@@ -235,12 +245,16 @@ class ClassifyReader(BaseReader): ...@@ -235,12 +245,16 @@ class ClassifyReader(BaseReader):
batch_token_ids = [record.token_ids for record in batch_records] batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records] batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records] batch_position_ids = [record.position_ids for record in batch_records]
if not self.is_inference:
batch_labels = [record.label_id for record in batch_records] batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1]) batch_labels = np.array(batch_labels).astype("int64").reshape(
[-1, 1])
if batch_records[0].qid is not None: if batch_records[0].qid is not None:
batch_qids = [record.qid for record in batch_records] batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) batch_qids = np.array(batch_qids).astype("int64").reshape(
[-1, 1])
else: else:
batch_qids = np.array([]).astype("int64").reshape([-1, 1]) batch_qids = np.array([]).astype("int64").reshape([-1, 1])
...@@ -254,8 +268,10 @@ class ClassifyReader(BaseReader): ...@@ -254,8 +268,10 @@ class ClassifyReader(BaseReader):
return_list = [ return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids, padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, batch_labels, batch_qids input_mask
] ]
if not self.is_inference:
return_list += [batch_labels, batch_qids]
return return_list return return_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册