提交 63ce906b 编写于 作者: G guosheng

Refine ChunkEvalutor by following comments

上级 1eaeacb2
......@@ -80,14 +80,16 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
"sensitivity) of chunks on the given mini-batch.");
AddOutput("F1-Score",
"(float). The evaluated F1-Score on the given mini-batch.");
AddOutput("NumInferChunks",
"(int64_t). The number of chunks in Inference on the given "
"mini-batch.");
AddOutput(
"NumInferChunks",
"(int). The number of chunks in Inference on the given mini-batch.");
AddOutput("NumLabelChunks",
"(int). The number of chunks in Label on the given mini-batch.");
AddOutput("NumCorrectChunks",
"(int). The number of chunks both in Inference and Label on the "
"given mini-batch.");
"NumLabelChunks",
"(int64_t). The number of chunks in Label on the given mini-batch.");
AddOutput(
"NumCorrectChunks",
"(int64_t). The number of chunks both in Inference and Label on the "
"given mini-batch.");
AddAttr<int>("num_chunk_types",
"(int). The number of chunk type. See below for details.");
AddAttr<std::string>(
......
......@@ -178,20 +178,19 @@ def main():
for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe)
for data in train_data():
outs = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + chunk_evaluator.metrics)
precision, recall, f1_score = chunk_evaluator.eval(exe)
avg_cost_val = np.array(outs[0])
precision_val = np.array(precision)
recall_val = np.array(recall)
f1_score_val = np.array(f1_score)
cost, precision, recall, f1_score = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + chunk_evaluator.metrics)
pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(
exe)
if batch_id % 10 == 0:
print("avg_cost=" + str(avg_cost_val))
print("precision_val=" + str(precision_val))
print("recall_val:" + str(recall_val))
print("f1_score_val:" + str(f1_score_val))
print("avg_cost:" + str(cost) + " precision:" + str(
precision) + " recall:" + str(recall) + " f1_score:" + str(
f1_score) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall)
+ " pass_f1_score:" + str(pass_f1_score))
# exit early for CI
exit(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册