未验证 提交 34db009c 编写于 作者: G guochaorong 提交者: GitHub

Merge pull request #1212 from kuke/fix_sequence_tagging_ner

Use metrics.chunk_evaluator for sequence_tagging_for_ner
......@@ -15,17 +15,23 @@ from utils import logger, load_dict
from utils_extend import to_lodtensor, get_embedding
def test(exe, chunk_evaluator, inference_program, test_data, place):
chunk_evaluator.reset(exe)
def test(exe, chunk_evaluator, inference_program, test_data, test_fetch_list,
place):
chunk_evaluator.reset()
for data in test_data():
word = to_lodtensor([x[0] for x in data], place)
mark = to_lodtensor([x[1] for x in data], place)
target = to_lodtensor([x[2] for x in data], place)
acc = exe.run(inference_program,
rets = exe.run(inference_program,
feed={"word": word,
"mark": mark,
"target": target})
return chunk_evaluator.eval(exe)
"target": target},
fetch_list=test_fetch_list)
num_infer = np.array(rets[0])
num_label = np.array(rets[1])
num_correct = np.array(rets[2])
chunk_evaluator.update(num_infer[0], num_label[0], num_correct[0])
return chunk_evaluator.eval()
def main(train_data_file,
......@@ -58,16 +64,16 @@ def main(train_data_file,
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
chunk_evaluator = fluid.evaluator.ChunkEvaluator(
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
chunk_evaluator = fluid.metrics.ChunkEvaluator()
inference_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(inference_program):
test_target = chunk_evaluator.metrics + chunk_evaluator.states
inference_program = fluid.io.get_inference_program(test_target)
test_fetch_list = [num_infer_chunks, num_label_chunks, num_correct_chunks]
if "CE_MODE_X" not in os.environ:
train_reader = paddle.batch(
......@@ -100,26 +106,29 @@ def main(train_data_file,
embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
embedding_param.set(word_vector_values, place)
time_begin = time.time()
for pass_id in six.moves.xrange(num_passes):
chunk_evaluator.reset(exe)
chunk_evaluator.reset()
for batch_id, data in enumerate(train_reader()):
cost, batch_precision, batch_recall, batch_f1_score = exe.run(
cost_var, nums_infer, nums_label, nums_correct = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + chunk_evaluator.metrics)
fetch_list=[
avg_cost, num_infer_chunks, num_label_chunks,
num_correct_chunks
])
if batch_id % 5 == 0:
print(cost)
print("Pass " + str(pass_id) + ", Batch " + str(
batch_id) + ", Cost " + str(cost[0]) + ", Precision " + str(
batch_precision[0]) + ", Recall " + str(batch_recall[0])
+ ", F1_score" + str(batch_f1_score[0]))
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe)
print("Pass " + str(pass_id) + ", Batch " + str(batch_id) +
", Cost " + str(cost_var[0]))
chunk_evaluator.update(nums_infer, nums_label, nums_correct)
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval()
print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall) +
" pass_f1_score:" + str(pass_f1_score))
test_pass_precision, test_pass_recall, test_pass_f1_score = test(
exe, chunk_evaluator, inference_program, test_reader, place)
exe, chunk_evaluator, inference_program, test_reader,
test_fetch_list, place)
print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
test_pass_precision) + " pass_recall:" + str(test_pass_recall) +
" pass_f1_score:" + str(test_pass_f1_score))
......@@ -128,12 +137,10 @@ def main(train_data_file,
fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'],
crf_decode, exe)
if ("CE_MODE_X" in os.environ) and (pass_id % 50 == 0):
if pass_id > 0:
if "CE_MODE_X" in os.environ:
print("kpis train_precision %f" % pass_precision)
print("kpis test_precision %f" % test_pass_precision)
print("kpis train_duration %f" % (time.time() - time_begin))
time_begin = time.time()
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册