提交 bdc832cb 编写于 作者: D Dong Zhihong

"add eval interface"

上级 233a305b
......@@ -30,6 +30,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
"Input (Label) of accuracy op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
"Output (Accuracy) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Correct"),
"Output (Correct) of AccuracyOp should not be null.");
auto inference_dim = ctx->GetInputDim("Out");
auto label_dim = ctx->GetInputDim("Label");
......@@ -43,6 +45,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
" the same as label.");
ctx->SetOutputDim("Accuracy", {1});
ctx->SetOutputDim("Correct", {1});
ctx->ShareLoD("Out", /*->*/ "Accuracy");
}
......@@ -65,6 +68,7 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label", "Label of the training data");
// TODO(typhoonzero): AddInput("Weight", ...
AddOutput("Accuracy", "The accuracy of current batch");
AddOutput("Correct", "The correct samples count of current batch");
AddComment(R"DOC(
Accuracy. It will print accuracy rate for classification.
......
......@@ -42,8 +42,10 @@ class AccuracyKernel : public framework::OpKernel<T> {
auto* indices = ctx.Input<Tensor>("Indices");
auto* label = ctx.Input<Tensor>("Label");
auto* accuracy = ctx.Output<Tensor>("Accuracy");
auto* correct = ctx.Output<Tensor>("Correct");
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
float* correct_data = correct->mutable_data<float>(ctx.GetPlace());
int* accuracy_data = accuracy->mutable_data<int>(ctx.GetPlace());
const int64_t* indices_data = indices->data<int64_t>();
const int64_t* label_data = label->data<int64_t>();
......@@ -68,7 +70,7 @@ class AccuracyKernel : public framework::OpKernel<T> {
}
}
// FIXME(typhoonzero): we don't accumulate the accuracy for now.
*correct_data = num_correct;
*accuracy_data =
static_cast<float>(num_correct) / static_cast<float>(num_samples);
}
......
......@@ -12,18 +12,35 @@ class Evaluator(object):
add increment operator to accumulate the metric states
"""
def __init__(self, evaluator_type, **kwargs):
def __init__(self, name, **kwargs):
self._states = []
self._helper = LayerHelper(layer_type=evaluator_type, **kwargs)
self._helper = LayerHelper(layer_type=name, **kwargs)
@staticmethod
def clear(self):
# def _update(self):
# """
# Updates the internal states througth operator
# """
# raise NotImplementedError()
def reset(self):
"""
clear metric states at the begin of each pass/user specified batch
Clear metric states at the begin of each pass/user specified batch
"""
raise NotImplementedError()
reset_program = Program()
for var in self._states:
zeros = helper.create_tmp_variable(dtype=var.data_type)
self._helper.append_op(
type="fill_constant",
outputs={"Out": [zeros]},
attrs={
"shape": var.shape,
"value": 0,
})
self._helper.append_op(
type="scale", inputs={"X": zeros}, outputs={"Out": var})
return reset_program
def evaluate(self):
def eval(self):
"""
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
"""
......@@ -31,6 +48,10 @@ class Evaluator(object):
class Accuracy(Evaluator):
"""
Accuracy need two state variable Total, Correct
"""
def __init__(self, input, label, k=1, **kwargs):
super(Accuracy, self).__init__("accuracy", **kwargs)
g_total = helper.create_global_variable(
......@@ -43,6 +64,8 @@ class Accuracy(Evaluator):
persistable=True,
dtype="int64",
shape=[1])
self._states.append(g_total)
self._states.append(g_correct)
topk_out = helper.create_tmp_variable(dtype=input.data_type)
topk_indices = helper.create_tmp_variable(dtype="int64")
......@@ -61,10 +84,34 @@ class Accuracy(Evaluator):
"Indices": [topk_indices],
"Label": [label]
},
outputs={"Accuracy": [acc_out]})
outputs={
"Accuracy": [acc_out],
"Correct": [tp_out],
})
helper.append_op(
type="sum", inputs={"X": [g_total, ], },
type="sum",
inputs={"X": [g_total, tp_out]},
outputs={"Out": [g_total]})
return acc_out
def eval(self):
eval_program = Program()
g_total = self._program
# This is demo for composing low level op to compute metric
class F1(Evaluator):
def __init__(self, input, label, **kwargs):
super(F1, self).__init__("F1", **kwargs)
super(Accuracy, self).__init__("accuracy", **kwargs)
g_total = helper.create_global_variable(
name=unique_name("Total"),
persistable=True,
dtype="int64",
shape=[1])
g_correct = helper.create_global_variable(
name=unique_name("Correct"),
persistable=True,
dtype="int64",
shape=[1])
......@@ -18,7 +18,8 @@ class TestAccuracyOp(OpTest):
num_correct += 1
break
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype("float32")
'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
'Correct': np.array([num_correct]).astype("int32")
}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册