From bdc832cba10538bbfb345bf4d6748de834af6273 Mon Sep 17 00:00:00 2001
From: Dong Zhihong
Date: Mon, 6 Nov 2017 19:26:17 -0800
Subject: [PATCH] "add eval interface"

---
 paddle/operators/accuracy_op.cc                |  4 ++
 paddle/operators/accuracy_op.h                 |  6 +-
 python/paddle/v2/framework/evaluator.py        | 67 ++++++++++++++++---
 .../v2/framework/tests/test_accuracy_op.py     |  3 +-
 4 files changed, 67 insertions(+), 13 deletions(-)

diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index 2a2a1e9cfd6..142883d9eae 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -30,6 +30,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
                    "Input (Label) of accuracy op should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
                    "Output (Accuracy) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Correct"),
+                   "Output (Correct) of AccuracyOp should not be null.");
 
     auto inference_dim = ctx->GetInputDim("Out");
     auto label_dim = ctx->GetInputDim("Label");
@@ -43,6 +45,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
                       " the same as label.");
 
     ctx->SetOutputDim("Accuracy", {1});
+    ctx->SetOutputDim("Correct", {1});
     ctx->ShareLoD("Out", /*->*/ "Accuracy");
   }
 
@@ -65,6 +68,7 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Label", "Label of the training data");
     // TODO(typhoonzero): AddInput("Weight", ...
     AddOutput("Accuracy", "The accuracy of current batch");
+    AddOutput("Correct", "The correct samples count of current batch");
 
     AddComment(R"DOC(
 Accuracy. It will print accuracy rate for classification.
diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h
index 1968b53d19a..cc0ea802f9a 100644
--- a/paddle/operators/accuracy_op.h
+++ b/paddle/operators/accuracy_op.h
@@ -42,8 +42,10 @@ class AccuracyKernel : public framework::OpKernel<T> {
     auto* indices = ctx.Input<Tensor>("Indices");
     auto* label = ctx.Input<Tensor>("Label");
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
+    auto* correct = ctx.Output<Tensor>("Correct");
 
-    float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
+    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
+    float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
 
     const int64_t* indices_data = indices->data<int64_t>();
     const int64_t* label_data = label->data<int64_t>();
@@ -68,7 +70,7 @@ class AccuracyKernel : public framework::OpKernel<T> {
       }
     }
 
-    // FIXME(typhoonzero): we don't accumulate the accuracy for now.
+    *correct_data = num_correct;
     *accuracy_data =
         static_cast<float>(num_correct) / static_cast<float>(num_samples);
   }
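
A quick reference for the kernel change above: the accuracy op now fills both
outputs in a single pass over the top-k indices. Below is a minimal NumPy
sketch of the intended semantics, mirroring the updated test further down;
the function name and shapes are illustrative, not part of the patch:

    import numpy as np

    def accuracy_reference(indices, labels):
        # indices: (n, k) int64 array of top-k predicted class ids per sample.
        # labels:  (n, 1) int64 array of ground-truth class ids.
        n = indices.shape[0]
        # A sample counts as correct if its label appears in its top-k row.
        num_correct = sum(int(labels[i][0] in indices[i]) for i in range(n))
        correct = np.array([num_correct]).astype("int32")
        accuracy = np.array([num_correct / float(n)]).astype("float32")
        return accuracy, correct
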
diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py
index 7536aa6ea19..4d305f899bc 100644
--- a/python/paddle/v2/framework/evaluator.py
+++ b/python/paddle/v2/framework/evaluator.py
@@ -12,18 +12,35 @@ class Evaluator(object):
     add increment operator to accumulate the metric states
     """
 
-    def __init__(self, evaluator_type, **kwargs):
+    def __init__(self, name, **kwargs):
         self._states = []
-        self._helper = LayerHelper(layer_type=evaluator_type, **kwargs)
+        self._helper = LayerHelper(layer_type=name, **kwargs)
 
-    @staticmethod
-    def clear(self):
+    # def _update(self):
+    #     """
+    #     Update the internal states through operators
+    #     """
+    #     raise NotImplementedError()
+
+    def reset(self):
         """
-        clear metric states at the begin of each pass/user specified batch
+        Clear metric states at the beginning of each pass/user-specified batch
         """
-        raise NotImplementedError()
+        reset_program = Program()
+        for var in self._states:
+            zeros = self._helper.create_tmp_variable(dtype=var.data_type)
+            self._helper.append_op(
+                type="fill_constant",
+                outputs={"Out": [zeros]},
+                attrs={
+                    "shape": var.shape,
+                    "value": 0,
+                })
+            self._helper.append_op(
+                type="scale", inputs={"X": zeros}, outputs={"Out": var})
+        return reset_program
 
-    def evaluate(self):
+    def eval(self):
         """
         Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
         """
@@ -31,6 +48,10 @@ class Evaluator(object):
 
 
 class Accuracy(Evaluator):
+    """
+    Accuracy needs two state variables: Total and Correct.
+    """
+
     def __init__(self, input, label, k=1, **kwargs):
         super(Accuracy, self).__init__("accuracy", **kwargs)
         g_total = helper.create_global_variable(
@@ -43,6 +64,8 @@ class Accuracy(Evaluator):
             persistable=True,
             dtype="int64",
             shape=[1])
+        self._states.append(g_total)
+        self._states.append(g_correct)
 
         topk_out = helper.create_tmp_variable(dtype=input.data_type)
         topk_indices = helper.create_tmp_variable(dtype="int64")
@@ -61,10 +84,34 @@ class Accuracy(Evaluator):
                 "Indices": [topk_indices],
                 "Label": [label]
             },
-            outputs={"Accuracy": [acc_out]})
+            outputs={
+                "Accuracy": [acc_out],
+                "Correct": [tp_out],
+            })
 
         helper.append_op(
-            type="sum", inputs={"X": [g_total, ], },
+            type="sum",
+            inputs={"X": [g_total, tp_out]},
             outputs={"Out": [g_total]})
-        return acc_out
+
+    def eval(self):
+        eval_program = Program()
+        g_total = self._program
+
+
+# This is a demo of composing low-level ops to compute a metric.
+class F1(Evaluator):
+    def __init__(self, input, label, **kwargs):
+        super(F1, self).__init__("F1", **kwargs)
+        g_total = helper.create_global_variable(
+            name=unique_name("Total"),
+            persistable=True,
+            dtype="int64",
+            shape=[1])
+        g_correct = helper.create_global_variable(
+            name=unique_name("Correct"),
+            persistable=True,
+            dtype="int64",
+            shape=[1])
diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py
index 6536c297e8e..8674f7523d4 100644
--- a/python/paddle/v2/framework/tests/test_accuracy_op.py
+++ b/python/paddle/v2/framework/tests/test_accuracy_op.py
@@ -18,7 +18,8 @@ class TestAccuracyOp(OpTest):
                     num_correct += 1
                     break
         self.outputs = {
-            'Accuracy': np.array([num_correct / float(n)]).astype("float32")
+            'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
+            'Correct': np.array([num_correct]).astype("int32")
         }
 
     def test_check_output(self):
--
GitLab
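
For orientation, the evaluator interface sketched in this patch is meant to
be driven once per pass, roughly as below. This is a hypothetical usage
sketch, not code from the patch: `exe`, `main_program`, `feeder`, and
`train_reader` are assumed to exist elsewhere, and a finished `eval()` is
assumed even though it is still a stub above.

    # Hypothetical driver loop for the Evaluator interface added above.
    # Only Accuracy, reset(), and eval() come from this patch; the rest
    # is assumed executor plumbing.
    accuracy = Accuracy(input=predict, label=label)
    for pass_id in range(num_passes):
        exe.run(accuracy.reset())          # zero the Total/Correct states
        for data in train_reader():
            # The accuracy op inside main_program updates the states here.
            exe.run(main_program, feed=feeder.feed(data))
        pass_acc = exe.run(accuracy.eval())  # merge mini-batch statistics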