提交 9a89b041 编写于 作者: G guosheng

Add ChunkEvaluator for multi-batches

上级 06a3a887
......@@ -32,6 +32,13 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
"Output(Recall) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("F1-Score"),
"Output(F1-Score) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("NumInferChunks"),
"Output(NumInferChunks) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("NumLabelChunks"),
"Output(NumLabelChunks) of ChunkEvalOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("NumCorrectChunks"),
"Output(NumCorrectChunks) of ChunkEvalOp should not be null.");
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");
......@@ -42,6 +49,9 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Precision", {1});
ctx->SetOutputDim("Recall", {1});
ctx->SetOutputDim("F1-Score", {1});
ctx->SetOutputDim("NumInferChunks", {1});
ctx->SetOutputDim("NumLabelChunks", {1});
ctx->SetOutputDim("NumCorrectChunks", {1});
}
protected:
......@@ -70,6 +80,14 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
"sensitivity) of chunks on the given mini-batch.");
AddOutput("F1-Score",
"(float). The evaluated F1-Score on the given mini-batch.");
AddOutput(
"NumInferChunks",
"(int). The number of chunks in Inference on the given mini-batch.");
AddOutput("NumLabelChunks",
"(int). The number of chunks in Label on the given mini-batch.");
AddOutput("NumCorrectChunks",
"(int). The number of chunks both in Inference and Label on the "
"given mini-batch.");
AddAttr<int>("num_chunk_types",
"(int). The number of chunk type. See below for details.");
AddAttr<std::string>(
......
......@@ -111,9 +111,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
std::vector<Segment> label_segments;
std::vector<Segment> output_segments;
std::set<int> excluded_chunk_types;
int64_t num_output_segments = 0;
int64_t num_label_segments = 0;
int64_t num_correct = 0;
if (context.Attr<std::string>("chunk_scheme") == "IOB") {
num_tag_types = 2;
tag_begin = 0;
......@@ -151,12 +149,24 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
auto* precision = context.Output<Tensor>("Precision");
auto* recall = context.Output<Tensor>("Recall");
auto* f1 = context.Output<Tensor>("F1-Score");
auto* num_infer_chunks = context.Output<Tensor>("NumInferChunks");
auto* num_label_chunks = context.Output<Tensor>("NumLabelChunks");
auto* num_correct_chunks = context.Output<Tensor>("NumCorrectChunks");
const int64_t* inference_data = inference->data<int64_t>();
const int64_t* label_data = label->data<int64_t>();
T* precision_data = precision->mutable_data<T>(context.GetPlace());
T* racall_data = recall->mutable_data<T>(context.GetPlace());
T* f1_data = f1->mutable_data<T>(context.GetPlace());
int64_t* num_infer_chunks_data =
num_infer_chunks->mutable_data<int64_t>(context.GetPlace());
int64_t* num_label_chunks_data =
num_label_chunks->mutable_data<int64_t>(context.GetPlace());
int64_t* num_correct_chunks_data =
num_correct_chunks->mutable_data<int64_t>(context.GetPlace());
*num_infer_chunks_data = 0;
*num_label_chunks_data = 0;
*num_correct_chunks_data = 0;
auto lod = label->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
......@@ -166,17 +176,23 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
for (int i = 0; i < num_sequences; ++i) {
int seq_length = lod[0][i + 1] - lod[0][i];
EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length,
output_segments, label_segments, num_output_segments,
num_label_segments, num_correct, num_chunk_types,
num_tag_types, other_chunk_type, tag_begin, tag_inside,
tag_end, tag_single, excluded_chunk_types);
output_segments, label_segments, *num_infer_chunks_data,
*num_label_chunks_data, *num_correct_chunks_data,
num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
tag_inside, tag_end, tag_single, excluded_chunk_types);
}
*precision_data = !num_output_segments ? 0 : static_cast<T>(num_correct) /
num_output_segments;
*racall_data = !num_label_segments ? 0 : static_cast<T>(num_correct) /
num_label_segments;
*f1_data = !num_correct ? 0 : 2 * (*precision_data) * (*racall_data) /
((*precision_data) + (*racall_data));
*precision_data = !(*num_infer_chunks_data)
? 0
: static_cast<T>(*num_correct_chunks_data) /
(*num_infer_chunks_data);
*racall_data = !(*num_label_chunks_data)
? 0
: static_cast<T>(*num_correct_chunks_data) /
(*num_label_chunks_data);
*f1_data = !(*num_correct_chunks_data)
? 0
: 2 * (*precision_data) * (*racall_data) /
((*precision_data) + (*racall_data));
}
void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
......
......@@ -4,7 +4,7 @@ import layers
from framework import Program, unique_name, Variable
from layer_helper import LayerHelper
__all__ = ['Accuracy']
__all__ = ['Accuracy', 'ChunkEvaluator']
def _clone_var_(block, var):
......@@ -132,3 +132,74 @@ class Accuracy(Evaluator):
correct = layers.cast(correct, dtype='float32', **kwargs)
out = layers.elementwise_div(x=correct, y=total, **kwargs)
return np.array(executor.run(eval_program, fetch_list=[out])[0])
class ChunkEvaluator(Evaluator):
"""
Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter
numbers.
"""
def __init__(self,
input,
label,
chunk_scheme,
num_chunk_types,
excluded_chunk_types=None,
**kwargs):
super(ChunkEvaluator, self).__init__("chunk_eval", **kwargs)
main_program = self.helper.main_program
if main_program.current_block().idx != 0:
raise ValueError("You can only invoke Evaluator in root block")
self.num_infer_chunks = self.create_state(
dtype='int64', shape=[1], suffix='num_infer_chunks')
self.num_label_chunks = self.create_state(
dtype='int64', shape=[1], suffix='num_label_chunks')
self.num_correct_chunks = self.create_state(
dtype='int64', shape=[1], suffix='num_correct_chunks')
kwargs = {'main_program': main_program}
precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval(
input=input,
label=label,
chunk_scheme=chunk_scheme,
num_chunk_types=num_chunk_types,
excluded_chunk_types=excluded_chunk_types,
**kwargs)
layers.sums(
input=[self.num_infer_chunks, num_infer_chunks],
out=self.num_infer_chunks,
**kwargs)
layers.sums(
input=[self.num_label_chunks, num_label_chunks],
out=self.num_label_chunks,
**kwargs)
layers.sums(
input=[self.num_correct_chunks, num_correct_chunks],
out=self.num_correct_chunks,
**kwargs)
self.metrics.extend([precision, recall, f1_score])
def eval(self, executor, eval_program=None):
if eval_program is None:
eval_program = Program()
block = eval_program.current_block()
kwargs = {'main_program': eval_program}
num_infer_chunks, num_label_chunks, num_correct_chunks = executor.run(
eval_program,
fetch_list=[_clone_var_(block, state) for state in self.states])
num_infer_chunks = num_infer_chunks[0]
num_label_chunks = num_label_chunks[0]
num_correct_chunks = num_correct_chunks[0]
precision = float(
num_correct_chunks) / num_infer_chunks if num_infer_chunks else 0
recall = float(
num_correct_chunks) / num_label_chunks if num_label_chunks else 0
f1_score = float(2 * precision * recall) / (
precision + recall) if num_correct_chunks else 0
return np.array(
[precision], dtype='float32'), np.array(
[recall], dtype='float32'), np.array(
[f1_score], dtype='float32')
......@@ -640,8 +640,8 @@ def chunk_eval(input,
excluded_chunk_types=None,
**kwargs):
"""
This function computes the accuracy using the input and label.
The output is the top_k inputs and their indices.
This function computes and outputs the precision, recall and
F1-score of chunk detection.
"""
helper = LayerHelper("chunk_eval", **kwargs)
......@@ -649,6 +649,9 @@ def chunk_eval(input,
precision = helper.create_tmp_variable(dtype="float32")
recall = helper.create_tmp_variable(dtype="float32")
f1_score = helper.create_tmp_variable(dtype="float32")
num_infer_chunks = helper.create_tmp_variable(dtype="int64")
num_label_chunks = helper.create_tmp_variable(dtype="int64")
num_correct_chunks = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="chunk_eval",
......@@ -657,14 +660,17 @@ def chunk_eval(input,
outputs={
"Precision": [precision],
"Recall": [recall],
"F1-Score": [f1_score]
"F1-Score": [f1_score],
"NumInferChunks": [num_infer_chunks],
"NumLabelChunks": [num_label_chunks],
"NumCorrectChunks": [num_correct_chunks]
},
attrs={
"num_chunk_types": num_chunk_types,
'chunk_scheme': chunk_scheme,
'excluded_chunk_types': excluded_chunk_types or []
})
return precision, recall, f1_score
return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks
def sequence_conv(input,
......
......@@ -150,7 +150,7 @@ def main():
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
precision, recall, f1_score = fluid.layers.chunk_eval(
chunk_evaluator = fluid.evaluator.ChunkEvaluator(
input=crf_decode,
label=target,
chunk_scheme="IOB",
......@@ -176,14 +176,16 @@ def main():
batch_id = 0
for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe)
for data in train_data():
outs = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, precision, recall, f1_score])
fetch_list=[avg_cost] + chunk_evaluator.metrics)
precision, recall, f1_score = chunk_evaluator.eval(exe)
avg_cost_val = np.array(outs[0])
precision_val = np.array(outs[1])
recall_val = np.array(outs[2])
f1_score_val = np.array(outs[3])
precision_val = np.array(precision)
recall_val = np.array(recall)
f1_score_val = np.array(f1_score)
if batch_id % 10 == 0:
print("avg_cost=" + str(avg_cost_val))
......
......@@ -147,7 +147,13 @@ class TestChunkEvalOp(OpTest):
'Recall': np.asarray(
[recall], dtype='float32'),
'F1-Score': np.asarray(
[f1], dtype='float32')
[f1], dtype='float32'),
'NumInferChunks': np.asarray(
[self.num_infer_chunks], dtype='int64'),
'NumLabelChunks': np.asarray(
[self.num_label_chunks], dtype='int64'),
'NumCorrectChunks': np.asarray(
[self.num_correct_chunks], dtype='int64')
}
def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册