Commit bdc832cb authored by Dong Zhihong

"add eval interface"

Parent 233a305b
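For context, a minimal sketch of how the evaluator interface introduced here might be driven. `Accuracy`, `reset()`, and `eval()` come from the diff below; the surrounding loop (`exe`, `train_reader`, `feeder`, `main_program`, `num_passes`) is hypothetical illustration, not part of this commit:

    # Hypothetical driver loop; only Accuracy/reset/eval come from this commit.
    accuracy = Accuracy(input=predict, label=label)

    for pass_id in range(num_passes):
        accuracy.reset()  # zero the Total/Correct state variables
        for data in train_reader():
            # each step runs the accuracy op plus the sum op that
            # accumulates the per-batch Correct count into global state
            exe.run(main_program, feed=feeder.feed(data))
        pass_acc = accuracy.eval()  # merge the mini-batch statistics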
@@ -30,6 +30,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
                   "Input (Label) of accuracy op should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
                   "Output (Accuracy) of AccuracyOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Correct"),
                   "Output (Correct) of AccuracyOp should not be null.");
    auto inference_dim = ctx->GetInputDim("Out");
    auto label_dim = ctx->GetInputDim("Label");
@@ -43,6 +45,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
                   " the same as label.");
    ctx->SetOutputDim("Accuracy", {1});
    ctx->SetOutputDim("Correct", {1});
    ctx->ShareLoD("Out", /*->*/ "Accuracy");
  }
@@ -65,6 +68,7 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
    AddInput("Label", "Label of the training data");
    // TODO(typhoonzero): AddInput("Weight", ...
    AddOutput("Accuracy", "The accuracy of current batch");
    AddOutput("Correct", "The correct samples count of current batch");
    AddComment(R"DOC(
Accuracy. It will print accuracy rate for classification.
......
@@ -42,8 +42,10 @@ class AccuracyKernel : public framework::OpKernel<T> {
    auto* indices = ctx.Input<Tensor>("Indices");
    auto* label = ctx.Input<Tensor>("Label");
    auto* accuracy = ctx.Output<Tensor>("Accuracy");
    auto* correct = ctx.Output<Tensor>("Correct");
    // Correct is an integer count; Accuracy is the float ratio assigned below.
    int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
    float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
    const int64_t* indices_data = indices->data<int64_t>();
    const int64_t* label_data = label->data<int64_t>();
@@ -68,7 +70,7 @@ class AccuracyKernel : public framework::OpKernel<T> {
      }
    }
    *correct_data = num_correct;
    *accuracy_data =
        static_cast<float>(num_correct) / static_cast<float>(num_samples);
  }
......
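The kernel now writes both outputs: `Correct` gets the raw match count and `Accuracy` the ratio. A NumPy sketch of the same computation (the function name and shapes are illustrative; the counting loop itself is elided in the hunk above):

    import numpy as np

    def accuracy_ref(indices, label):
        # indices: (num_samples, k) int64 top-k predicted class ids
        # label:   (num_samples, 1) int64 ground-truth class ids
        num_correct = int(np.sum(np.any(indices == label, axis=1)))
        accuracy = float(num_correct) / indices.shape[0]
        return accuracy, num_correct  # the Accuracy and Correct outputs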
@@ -12,18 +12,35 @@ class Evaluator(object):
    add increment operator to accumulate the metric states
    """

    def __init__(self, name, **kwargs):
        self._states = []
        self._helper = LayerHelper(layer_type=name, **kwargs)

    # def _update(self):
    #     """
    #     Updates the internal states through operators
    #     """
    #     raise NotImplementedError()

    def reset(self):
        """
        Clear metric states at the beginning of each pass/user-specified batch
        """
        reset_program = Program()
        for var in self._states:
            zeros = self._helper.create_tmp_variable(dtype=var.data_type)
            self._helper.append_op(
                type="fill_constant",
                outputs={"Out": [zeros]},
                attrs={
                    "shape": var.shape,
                    "value": 0,
                })
            self._helper.append_op(
                type="scale", inputs={"X": zeros}, outputs={"Out": var})
        return reset_program
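In effect, `reset()` builds a program that zeroes every registered state variable: `fill_constant` materializes a zero tensor of the state's shape, and `scale` (with its default scale of 1.0) copies it into the state. A NumPy rendering of the intended semantics (illustrative only):

    import numpy as np

    def reset_states(states):
        # equivalent effect of the fill_constant + scale pair above
        for var in states:
            var[...] = np.zeros(var.shape, dtype=var.dtype)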
    def eval(self):
        """
        Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
        """
@@ -31,6 +48,10 @@ class Evaluator(object):
class Accuracy(Evaluator):
    """
    Accuracy needs two state variables: Total and Correct.
    """

    def __init__(self, input, label, k=1, **kwargs):
        super(Accuracy, self).__init__("accuracy", **kwargs)
        g_total = self._helper.create_global_variable(
@@ -43,6 +64,8 @@ class Accuracy(Evaluator):
            persistable=True,
            dtype="int64",
            shape=[1])
        self._states.append(g_total)
        self._states.append(g_correct)
        topk_out = self._helper.create_tmp_variable(dtype=input.data_type)
        topk_indices = self._helper.create_tmp_variable(dtype="int64")
@@ -61,10 +84,34 @@ class Accuracy(Evaluator):
                "Indices": [topk_indices],
                "Label": [label]
            },
            outputs={
                "Accuracy": [acc_out],
                "Correct": [tp_out],
            })
        self._helper.append_op(
            type="sum",
            inputs={"X": [g_total, tp_out]},
            outputs={"Out": [g_total]})
        return acc_out
    def eval(self):
        eval_program = Program()
        g_total = self._program


# This is a demo of composing low-level ops to compute a metric.
class F1(Evaluator):
    def __init__(self, input, label, **kwargs):
        super(F1, self).__init__("F1", **kwargs)
        g_total = self._helper.create_global_variable(
            name=unique_name("Total"),
            persistable=True,
            dtype="int64",
            shape=[1])
        g_correct = self._helper.create_global_variable(
            name=unique_name("Correct"),
            persistable=True,
            dtype="int64",
            shape=[1])
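Two loose ends worth noting: `Accuracy.eval()` is still a stub in this commit (it builds an `eval_program` and stops), and the single `sum` op accumulates the per-batch `Correct` count into `g_total`. Once both `Total` and `Correct` states hold accumulated counts, the merged result would plausibly be their ratio; a sketch of that final computation in NumPy terms (an assumption about where the interface is headed, not code from this commit):

    import numpy as np

    def merge_accuracy(g_correct, g_total):
        # pass-level accuracy from the accumulated state variables
        return float(g_correct) / float(g_total)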
@@ -18,7 +18,8 @@ class TestAccuracyOp(OpTest):
                    num_correct += 1
                    break
        self.outputs = {
            'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
            'Correct': np.array([num_correct]).astype("int32")
        }

    def test_check_output(self):
......