From 9a89b041ba12ab8bcf9e36e218486b224eae1d33 Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 11 Dec 2017 23:23:46 +0800 Subject: [PATCH] Add ChunkEvaluator for multi-batches --- paddle/operators/chunk_eval_op.cc | 18 +++++ paddle/operators/chunk_eval_op.h | 42 +++++++---- python/paddle/v2/fluid/evaluator.py | 73 ++++++++++++++++++- python/paddle/v2/fluid/layers.py | 14 +++- .../tests/book/test_label_semantic_roles.py | 12 +-- .../v2/fluid/tests/test_chunk_eval_op.py | 8 +- 6 files changed, 143 insertions(+), 24 deletions(-) diff --git a/paddle/operators/chunk_eval_op.cc b/paddle/operators/chunk_eval_op.cc index 94127ab33e5..ff2a08ccac9 100644 --- a/paddle/operators/chunk_eval_op.cc +++ b/paddle/operators/chunk_eval_op.cc @@ -32,6 +32,13 @@ class ChunkEvalOp : public framework::OperatorWithKernel { "Output(Recall) of ChunkEvalOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("F1-Score"), "Output(F1-Score) of ChunkEvalOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("NumInferChunks"), + "Output(NumInferChunks) of ChunkEvalOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("NumLabelChunks"), + "Output(NumLabelChunks) of ChunkEvalOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("NumCorrectChunks"), + "Output(NumCorrectChunks) of ChunkEvalOp should not be null."); auto inference_dim = ctx->GetInputDim("Inference"); auto label_dim = ctx->GetInputDim("Label"); @@ -42,6 +49,9 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Precision", {1}); ctx->SetOutputDim("Recall", {1}); ctx->SetOutputDim("F1-Score", {1}); + ctx->SetOutputDim("NumInferChunks", {1}); + ctx->SetOutputDim("NumLabelChunks", {1}); + ctx->SetOutputDim("NumCorrectChunks", {1}); } protected: @@ -70,6 +80,14 @@ class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker { "sensitivity) of chunks on the given mini-batch."); AddOutput("F1-Score", "(float). The evaluated F1-Score on the given mini-batch."); + AddOutput( + "NumInferChunks", + "(int). The number of chunks in Inference on the given mini-batch."); + AddOutput("NumLabelChunks", + "(int). The number of chunks in Label on the given mini-batch."); + AddOutput("NumCorrectChunks", + "(int). The number of chunks both in Inference and Label on the " + "given mini-batch."); AddAttr("num_chunk_types", "(int). The number of chunk type. See below for details."); AddAttr( diff --git a/paddle/operators/chunk_eval_op.h b/paddle/operators/chunk_eval_op.h index dd88f2553bb..5d703333c74 100644 --- a/paddle/operators/chunk_eval_op.h +++ b/paddle/operators/chunk_eval_op.h @@ -111,9 +111,7 @@ class ChunkEvalKernel : public framework::OpKernel { std::vector label_segments; std::vector output_segments; std::set excluded_chunk_types; - int64_t num_output_segments = 0; - int64_t num_label_segments = 0; - int64_t num_correct = 0; + if (context.Attr("chunk_scheme") == "IOB") { num_tag_types = 2; tag_begin = 0; @@ -151,12 +149,24 @@ class ChunkEvalKernel : public framework::OpKernel { auto* precision = context.Output("Precision"); auto* recall = context.Output("Recall"); auto* f1 = context.Output("F1-Score"); + auto* num_infer_chunks = context.Output("NumInferChunks"); + auto* num_label_chunks = context.Output("NumLabelChunks"); + auto* num_correct_chunks = context.Output("NumCorrectChunks"); const int64_t* inference_data = inference->data(); const int64_t* label_data = label->data(); T* precision_data = precision->mutable_data(context.GetPlace()); T* racall_data = recall->mutable_data(context.GetPlace()); T* f1_data = f1->mutable_data(context.GetPlace()); + int64_t* num_infer_chunks_data = + num_infer_chunks->mutable_data(context.GetPlace()); + int64_t* num_label_chunks_data = + num_label_chunks->mutable_data(context.GetPlace()); + int64_t* num_correct_chunks_data = + num_correct_chunks->mutable_data(context.GetPlace()); + *num_infer_chunks_data = 0; + *num_label_chunks_data = 0; + *num_correct_chunks_data = 0; auto lod = label->lod(); PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); @@ -166,17 +176,23 @@ class ChunkEvalKernel : public framework::OpKernel { for (int i = 0; i < num_sequences; ++i) { int seq_length = lod[0][i + 1] - lod[0][i]; EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i], seq_length, - output_segments, label_segments, num_output_segments, - num_label_segments, num_correct, num_chunk_types, - num_tag_types, other_chunk_type, tag_begin, tag_inside, - tag_end, tag_single, excluded_chunk_types); + output_segments, label_segments, *num_infer_chunks_data, + *num_label_chunks_data, *num_correct_chunks_data, + num_chunk_types, num_tag_types, other_chunk_type, tag_begin, + tag_inside, tag_end, tag_single, excluded_chunk_types); } - *precision_data = !num_output_segments ? 0 : static_cast(num_correct) / - num_output_segments; - *racall_data = !num_label_segments ? 0 : static_cast(num_correct) / - num_label_segments; - *f1_data = !num_correct ? 0 : 2 * (*precision_data) * (*racall_data) / - ((*precision_data) + (*racall_data)); + *precision_data = !(*num_infer_chunks_data) + ? 0 + : static_cast(*num_correct_chunks_data) / + (*num_infer_chunks_data); + *racall_data = !(*num_label_chunks_data) + ? 0 + : static_cast(*num_correct_chunks_data) / + (*num_label_chunks_data); + *f1_data = !(*num_correct_chunks_data) + ? 0 + : 2 * (*precision_data) * (*racall_data) / + ((*precision_data) + (*racall_data)); } void EvalOneSeq(const int64_t* output, const int64_t* label, int length, diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py index 137c5736226..2d23ff0a166 100644 --- a/python/paddle/v2/fluid/evaluator.py +++ b/python/paddle/v2/fluid/evaluator.py @@ -4,7 +4,7 @@ import layers from framework import Program, unique_name, Variable from layer_helper import LayerHelper -__all__ = ['Accuracy'] +__all__ = ['Accuracy', 'ChunkEvaluator'] def _clone_var_(block, var): @@ -132,3 +132,74 @@ class Accuracy(Evaluator): correct = layers.cast(correct, dtype='float32', **kwargs) out = layers.elementwise_div(x=correct, y=total, **kwargs) return np.array(executor.run(eval_program, fetch_list=[out])[0]) + + +class ChunkEvaluator(Evaluator): + """ + Accumulate counter numbers output by chunk_eval from mini-batches and + compute the precision recall and F1-score using the accumulated counter + numbers. + """ + + def __init__(self, + input, + label, + chunk_scheme, + num_chunk_types, + excluded_chunk_types=None, + **kwargs): + super(ChunkEvaluator, self).__init__("chunk_eval", **kwargs) + main_program = self.helper.main_program + if main_program.current_block().idx != 0: + raise ValueError("You can only invoke Evaluator in root block") + + self.num_infer_chunks = self.create_state( + dtype='int64', shape=[1], suffix='num_infer_chunks') + self.num_label_chunks = self.create_state( + dtype='int64', shape=[1], suffix='num_label_chunks') + self.num_correct_chunks = self.create_state( + dtype='int64', shape=[1], suffix='num_correct_chunks') + kwargs = {'main_program': main_program} + precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks = layers.chunk_eval( + input=input, + label=label, + chunk_scheme=chunk_scheme, + num_chunk_types=num_chunk_types, + excluded_chunk_types=excluded_chunk_types, + **kwargs) + layers.sums( + input=[self.num_infer_chunks, num_infer_chunks], + out=self.num_infer_chunks, + **kwargs) + layers.sums( + input=[self.num_label_chunks, num_label_chunks], + out=self.num_label_chunks, + **kwargs) + layers.sums( + input=[self.num_correct_chunks, num_correct_chunks], + out=self.num_correct_chunks, + **kwargs) + + self.metrics.extend([precision, recall, f1_score]) + + def eval(self, executor, eval_program=None): + if eval_program is None: + eval_program = Program() + block = eval_program.current_block() + kwargs = {'main_program': eval_program} + num_infer_chunks, num_label_chunks, num_correct_chunks = executor.run( + eval_program, + fetch_list=[_clone_var_(block, state) for state in self.states]) + num_infer_chunks = num_infer_chunks[0] + num_label_chunks = num_label_chunks[0] + num_correct_chunks = num_correct_chunks[0] + precision = float( + num_correct_chunks) / num_infer_chunks if num_infer_chunks else 0 + recall = float( + num_correct_chunks) / num_label_chunks if num_label_chunks else 0 + f1_score = float(2 * precision * recall) / ( + precision + recall) if num_correct_chunks else 0 + return np.array( + [precision], dtype='float32'), np.array( + [recall], dtype='float32'), np.array( + [f1_score], dtype='float32') diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index 98a04ea9c25..809534884be 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -640,8 +640,8 @@ def chunk_eval(input, excluded_chunk_types=None, **kwargs): """ - This function computes the accuracy using the input and label. - The output is the top_k inputs and their indices. + This function computes and outputs the precision, recall and + F1-score of chunk detection. """ helper = LayerHelper("chunk_eval", **kwargs) @@ -649,6 +649,9 @@ def chunk_eval(input, precision = helper.create_tmp_variable(dtype="float32") recall = helper.create_tmp_variable(dtype="float32") f1_score = helper.create_tmp_variable(dtype="float32") + num_infer_chunks = helper.create_tmp_variable(dtype="int64") + num_label_chunks = helper.create_tmp_variable(dtype="int64") + num_correct_chunks = helper.create_tmp_variable(dtype="int64") helper.append_op( type="chunk_eval", @@ -657,14 +660,17 @@ def chunk_eval(input, outputs={ "Precision": [precision], "Recall": [recall], - "F1-Score": [f1_score] + "F1-Score": [f1_score], + "NumInferChunks": [num_infer_chunks], + "NumLabelChunks": [num_label_chunks], + "NumCorrectChunks": [num_correct_chunks] }, attrs={ "num_chunk_types": num_chunk_types, 'chunk_scheme': chunk_scheme, 'excluded_chunk_types': excluded_chunk_types or [] }) - return precision, recall, f1_score + return precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks def sequence_conv(input, diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index d2693b602ea..caa51b5df4e 100644 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -150,7 +150,7 @@ def main(): crf_decode = fluid.layers.crf_decoding( input=feature_out, param_attr=fluid.ParamAttr(name='crfw')) - precision, recall, f1_score = fluid.layers.chunk_eval( + chunk_evaluator = fluid.evaluator.ChunkEvaluator( input=crf_decode, label=target, chunk_scheme="IOB", @@ -176,14 +176,16 @@ def main(): batch_id = 0 for pass_id in xrange(PASS_NUM): + chunk_evaluator.reset(exe) for data in train_data(): outs = exe.run(fluid.default_main_program(), feed=feeder.feed(data), - fetch_list=[avg_cost, precision, recall, f1_score]) + fetch_list=[avg_cost] + chunk_evaluator.metrics) + precision, recall, f1_score = chunk_evaluator.eval(exe) avg_cost_val = np.array(outs[0]) - precision_val = np.array(outs[1]) - recall_val = np.array(outs[2]) - f1_score_val = np.array(outs[3]) + precision_val = np.array(precision) + recall_val = np.array(recall) + f1_score_val = np.array(f1_score) if batch_id % 10 == 0: print("avg_cost=" + str(avg_cost_val)) diff --git a/python/paddle/v2/fluid/tests/test_chunk_eval_op.py b/python/paddle/v2/fluid/tests/test_chunk_eval_op.py index 819e65a6534..53bf6f815b8 100644 --- a/python/paddle/v2/fluid/tests/test_chunk_eval_op.py +++ b/python/paddle/v2/fluid/tests/test_chunk_eval_op.py @@ -147,7 +147,13 @@ class TestChunkEvalOp(OpTest): 'Recall': np.asarray( [recall], dtype='float32'), 'F1-Score': np.asarray( - [f1], dtype='float32') + [f1], dtype='float32'), + 'NumInferChunks': np.asarray( + [self.num_infer_chunks], dtype='int64'), + 'NumLabelChunks': np.asarray( + [self.num_label_chunks], dtype='int64'), + 'NumCorrectChunks': np.asarray( + [self.num_correct_chunks], dtype='int64') } def setUp(self): -- GitLab