提交 63ce906b 编写于 作者: G guosheng

Refine ChunkEvalutor by following comments

上级 1eaeacb2
...@@ -80,14 +80,16 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,14 +80,16 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
"sensitivity) of chunks on the given mini-batch."); "sensitivity) of chunks on the given mini-batch.");
AddOutput("F1-Score", AddOutput("F1-Score",
"(float). The evaluated F1-Score on the given mini-batch."); "(float). The evaluated F1-Score on the given mini-batch.");
AddOutput("NumInferChunks",
"(int64_t). The number of chunks in Inference on the given "
"mini-batch.");
AddOutput( AddOutput(
"NumInferChunks", "NumLabelChunks",
"(int). The number of chunks in Inference on the given mini-batch."); "(int64_t). The number of chunks in Label on the given mini-batch.");
AddOutput("NumLabelChunks", AddOutput(
"(int). The number of chunks in Label on the given mini-batch."); "NumCorrectChunks",
AddOutput("NumCorrectChunks", "(int64_t). The number of chunks both in Inference and Label on the "
"(int). The number of chunks both in Inference and Label on the " "given mini-batch.");
"given mini-batch.");
AddAttr<int>("num_chunk_types", AddAttr<int>("num_chunk_types",
"(int). The number of chunk type. See below for details."); "(int). The number of chunk type. See below for details.");
AddAttr<std::string>( AddAttr<std::string>(
......
...@@ -178,20 +178,19 @@ def main(): ...@@ -178,20 +178,19 @@ def main():
for pass_id in xrange(PASS_NUM): for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe) chunk_evaluator.reset(exe)
for data in train_data(): for data in train_data():
outs = exe.run(fluid.default_main_program(), cost, precision, recall, f1_score = exe.run(
feed=feeder.feed(data), fluid.default_main_program(),
fetch_list=[avg_cost] + chunk_evaluator.metrics) feed=feeder.feed(data),
precision, recall, f1_score = chunk_evaluator.eval(exe) fetch_list=[avg_cost] + chunk_evaluator.metrics)
avg_cost_val = np.array(outs[0]) pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(
precision_val = np.array(precision) exe)
recall_val = np.array(recall)
f1_score_val = np.array(f1_score)
if batch_id % 10 == 0: if batch_id % 10 == 0:
print("avg_cost=" + str(avg_cost_val)) print("avg_cost:" + str(cost) + " precision:" + str(
print("precision_val=" + str(precision_val)) precision) + " recall:" + str(recall) + " f1_score:" + str(
print("recall_val:" + str(recall_val)) f1_score) + " pass_precision:" + str(
print("f1_score_val:" + str(f1_score_val)) pass_precision) + " pass_recall:" + str(pass_recall)
+ " pass_f1_score:" + str(pass_f1_score))
# exit early for CI # exit early for CI
exit(0) exit(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册