提交 2f2fe7af 编写于 作者: T tianxin

do classify_inference by infer_program

fix #207
上级 39d4571d
......@@ -316,4 +316,4 @@ python -u predict_classifier.py \
实际使用时,需要通过 `init_checkpoint` 指定预测用的模型,通过 `predict_set` 指定待预测的数据文件,通过 `num_labels` 配置分类的类别数目;
**Note**: predict_set 的数据格式与 dev_set 和 test_set 的数据格式完全一致,是由 text_a、text_b(可选) 、label 组成的2列/3列 tsv 文件,predict_set 中的 label 列起到占位符的作用,全部置 0 即可;
**Note**: predict_set 的数据格式是由 text_a、text_b(可选) 组成的1列/2列 tsv 文件;
......@@ -65,7 +65,7 @@ def create_model(args, pyreader_name, ernie_config, is_prediction=False):
if is_prediction:
probs = fluid.layers.softmax(logits)
feed_targets_name = [
src_ids.name, pos_ids.name, sent_ids.name, input_mask.name
src_ids.name, sent_ids.name, pos_ids.name, input_mask.name
]
return pyreader, probs, feed_targets_name
......
......@@ -37,6 +37,7 @@ parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "options to init, resume and save model.")
model_g.add_arg("ernie_config_path", str, None, "Path to the json file for bert model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("save_inference_model_path", str, "inference_model", "If set, save the inference model to this path.")
model_g.add_arg("use_fp16", bool, False, "Whether to resume parameters from fp16 checkpoint.")
model_g.add_arg("num_labels", int, 2, "num labels for classify")
......@@ -65,7 +66,8 @@ def main(args):
label_map_config=args.label_map_config,
max_seq_len=args.max_seq_len,
do_lower_case=args.do_lower_case,
in_tokens=False)
in_tokens=False,
is_inference=True)
predict_prog = fluid.Program()
predict_startup = fluid.Program()
......@@ -95,7 +97,25 @@ def main(args):
else:
raise ValueError("args 'init_checkpoint' should be set for prediction!")
predict_exe = fluid.Executor(place)
assert args.save_inference_model_path, "args save_inference_model_path should be set for prediction"
_, ckpt_dir = os.path.split(args.init_checkpoint.rstrip('/'))
dir_name = ckpt_dir + '_inference_model'
model_path = os.path.join(args.save_inference_model_path, dir_name)
print("save inference model to %s" % model_path)
fluid.io.save_inference_model(
model_path,
feed_target_names, [probs],
exe,
main_program=predict_prog)
print("load inference model from %s" % model_path)
infer_program, feed_target_names, probs = fluid.io.load_inference_model(
model_path, exe)
src_ids = feed_target_names[0]
sent_ids = feed_target_names[1]
pos_ids = feed_target_names[2]
input_mask = feed_target_names[3]
predict_data_generator = reader.data_generator(
input_file=args.predict_set,
......@@ -103,25 +123,24 @@ def main(args):
epoch=1,
shuffle=False)
predict_pyreader.decorate_tensor_provider(predict_data_generator)
predict_pyreader.start()
all_results = []
time_begin = time.time()
while True:
try:
results = predict_exe.run(program=predict_prog, fetch_list=[probs.name])
all_results.extend(results[0])
except fluid.core.EOFException:
predict_pyreader.reset()
break
time_end = time.time()
np.set_printoptions(precision=4, suppress=True)
print("-------------- prediction results --------------")
for index, result in enumerate(all_results):
print(str(index) + '\t{}'.format(result))
np.set_printoptions(precision=4, suppress=True)
index = 0
for sample in predict_data_generator():
src_ids_data = sample[0]
sent_ids_data = sample[1]
pos_ids_data = sample[2]
input_mask_data = sample[3]
output = exe.run(
infer_program,
feed={src_ids: src_ids_data,
sent_ids: sent_ids_data,
pos_ids: pos_ids_data,
input_mask: input_mask_data},
fetch_list=probs)
for single_result in output[0]:
print("example_index:{}\t{}".format(index, single_result))
index += 1
if __name__ == '__main__':
print_arguments(args)
......
......@@ -28,6 +28,7 @@ class BaseReader(object):
max_seq_len=512,
do_lower_case=True,
in_tokens=False,
is_inference=False,
random_seed=None):
self.max_seq_len = max_seq_len
self.tokenizer = tokenization.FullTokenizer(
......@@ -37,6 +38,7 @@ class BaseReader(object):
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.in_tokens = in_tokens
self.is_inference = is_inference
np.random.seed(random_seed)
......@@ -141,25 +143,33 @@ class BaseReader(object):
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
if self.label_map:
label_id = self.label_map[example.label]
if self.is_inference:
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids)
else:
label_id = example.label
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'])
qid = None
if "qid" in example._fields:
qid = example.qid
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id,
qid=qid)
if self.label_map:
label_id = self.label_map[example.label]
else:
label_id = example.label
Record = namedtuple('Record', [
'token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'
])
qid = None
if "qid" in example._fields:
qid = example.qid
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id,
qid=qid)
return record
def _prepare_batch_data(self, examples, batch_size, phase=None):
......@@ -235,14 +245,18 @@ class ClassifyReader(BaseReader):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
if batch_records[0].qid is not None:
batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
else:
batch_qids = np.array([]).astype("int64").reshape([-1, 1])
if not self.is_inference:
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape(
[-1, 1])
if batch_records[0].qid is not None:
batch_qids = [record.qid for record in batch_records]
batch_qids = np.array(batch_qids).astype("int64").reshape(
[-1, 1])
else:
batch_qids = np.array([]).astype("int64").reshape([-1, 1])
# padding
padded_token_ids, input_mask = pad_batch_data(
......@@ -254,8 +268,10 @@ class ClassifyReader(BaseReader):
return_list = [
padded_token_ids, padded_text_type_ids, padded_position_ids,
input_mask, batch_labels, batch_qids
input_mask
]
if not self.is_inference:
return_list += [batch_labels, batch_qids]
return return_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册