Commit 65dbbd57 authored by yangyaming

Add and pass unittests.

Parent 06c7c8c8
paddle/operators/precision_recall_op.cc:

@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/operators/precision_recall_op.h"
+
 namespace paddle {
 namespace operators {
@@ -37,13 +39,15 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
     if (ctx->HasInput("Weights")) {
       auto weights_dims = ctx->GetInputDim("Weights");
-      PADDLE_ENFORCE_EQ(weights_dims, {predictions_dims[0], 1},
+      PADDLE_ENFORCE_EQ(weights_dims,
+                        framework::make_ddim({predictions_dims[0], 1}),
                         "The shape of Input(Weights) should be "
                         "[batch_size, 1].");
     }
     if (ctx->HasInput("StatesInfo")) {
       auto states_dims = ctx->GetInputDim("StatesInfo");
-      PADDLE_ENFORCE_EQ(states_dims, {predictions_dims[1], 4},
+      PADDLE_ENFORCE_EQ(states_dims,
+                        framework::make_ddim({predictions_dims[1], 4}),
                         "The shape of Input(StatesInfo) should be "
                         "[class_number, 4].");
     }
@@ -71,6 +75,12 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
     // [ TP, FP, TN, FN ]
     ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4});
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::ToDataType(ctx.Input<Tensor>("Predictions")->type());
+  }
 };

 class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -98,6 +108,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
              "provided, current state will be accumulated to this state and "
              "the accumulation state will be as the output state.")
         .AsDispensable();
+    AddOutput("BatchMetrics", "");
+    AddOutput("AccumMetrics", "");
+    AddOutput("AccumStatesInfo", "");
     AddComment(R"DOC(
 )DOC");
@@ -113,6 +126,4 @@ REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
 REGISTER_OP_CPU_KERNEL(
     precision_recall,
     ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
-    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, int>,
-    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>,
-    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, int64_t>);
+    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>);
paddle/operators/precision_recall_op.h:

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"

 namespace paddle {
 namespace operators {
@@ -37,7 +39,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
     auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");

     const T* predictions_data = in0->data<T>();
-    const T* labels_data = in1->data<T>();
+    const int* labels_data = in1->data<int>();
     const T* weights_data = in2 ? in2->data<T>() : nullptr;
     const T* states_data = in3 ? in3->data<T>() : nullptr;
     T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
@@ -45,7 +47,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
     out2->mutable_data<T>(ctx.GetPlace());
     auto accum_states = EigenMatrix<T>::From(*out2);
     accum_states.setZero();
-    T* accum_states_data = out2->data<T>(ctx.GetPlace());
+    T* accum_states_data = out2->data<T>();

     size_t sample_num = in0->dims()[0];
     size_t class_dim = in0->dims()[1];
@@ -76,7 +78,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
           accum_states_data[j * state_var_num + TN] += w;
         }
         accum_states_data[max_idx * state_var_num + TN] -= w;
-        accum_states_data[labels_data[j] * state_var_num + TN] -= w;
+        accum_states_data[labels_data[i] * state_var_num + TN] -= w;
       }
     }
@@ -108,7 +110,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
     if (tp_count > 0.0 || fn_count > 0.0) {
       return tp_count / (tp_count + fn_count);
     }
-    return 1.0
+    return 1.0;
   }

   static inline T CalcF1Score(T precision, T recall) {
@@ -120,7 +122,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
  protected:
   void ComputeMetrics(const T* states_data, T* metrics_data,
-                      size_t state_var_num, size_t class_dim) {
+                      size_t state_var_num, size_t class_dim) const {
     T total_tp_count = 0;
     T total_fp_count = 0;
     T total_fn_count = 0;
@@ -143,7 +145,7 @@ class PrecisionRecallKernel : public framework::OpKernel<T> {
     T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
     T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count);
-    T micro_f1_score = CalcRecall(micro_avg_precision, micro_avg_recall);
+    T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall);

     // fill metrics data
     metrics_data[0] = macro_avg_precision;
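Both ComputeMetrics in the kernel above and compute_metrics in the new unit test below reduce the per-class [TP, FP, TN, FN] counters to six numbers, laid out as [macro precision, macro recall, macro F1, micro precision, micro recall, micro F1]. Macro averaging computes precision and recall per class and averages them; micro averaging pools the raw counts across classes first. When a denominator would be zero, precision and recall fall back to 1.0 and F1 to 0.0. A minimal standalone NumPy sketch of that convention (the states matrix here is made up purely for illustration and is not taken from the operator code):

# Standalone sketch of the macro/micro averaging used by ComputeMetrics
# and by compute_metrics in the test below. One row per class, columns
# are [TP, FP, TN, FN]; the numbers are hypothetical.
import numpy as np

def precision(tp, fp):
    # Convention: precision is 1.0 when there are no positive predictions.
    return tp / (tp + fp) if (tp > 0.0 or fp > 0.0) else 1.0

def recall(tp, fn):
    # Convention: recall is 1.0 when there are no positive labels.
    return tp / (tp + fn) if (tp > 0.0 or fn > 0.0) else 1.0

def f1(p, r):
    return 2 * p * r / (p + r) if (p > 0.0 or r > 0.0) else 0.0

states = np.array([[4., 1., 10., 2.],   # class 0
                   [3., 2., 11., 1.],   # class 1
                   [5., 0., 12., 0.]])  # class 2

# Macro: average the per-class metrics, then combine into F1.
macro_p = np.mean([precision(tp, fp) for tp, fp, _, _ in states])
macro_r = np.mean([recall(tp, fn) for tp, _, _, fn in states])
print(macro_p, macro_r, f1(macro_p, macro_r))

# Micro: pool the counts over all classes first, then compute once.
tp, fp, _, fn = states.sum(axis=0)
micro_p, micro_r = precision(tp, fp), recall(tp, fn)
print(micro_p, micro_r, f1(micro_p, micro_r))

The new test script that follows re-implements the same bookkeeping in NumPy and hands it to the OpTest harness as the expected outputs.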
import unittest
import numpy as np
from op_test import OpTest


def calc_precision(tp_count, fp_count):
    if tp_count > 0.0 or fp_count > 0.0:
        return tp_count / (tp_count + fp_count)
    return 1.0


def calc_recall(tp_count, fn_count):
    if tp_count > 0.0 or fn_count > 0.0:
        return tp_count / (tp_count + fn_count)
    return 1.0


def calc_f1_score(precision, recall):
    if precision > 0.0 or recall > 0.0:
        return 2 * precision * recall / (precision + recall)
    return 0.0


def get_states(predictions, labels, weights=None):
    ins_num = predictions.shape[0]
    class_num = predictions.shape[1]
    # TP FP TN FN
    states = np.zeros((class_num, 4)).astype('float32')
    for i in xrange(ins_num):
        w = weights[i] if weights is not None else 1.0
        max_idx = np.argmax(predictions[i])
        if max_idx == labels[i][0]:
            states[max_idx][0] += w
            for j in xrange(class_num):
                states[j][2] += w
            states[max_idx][2] -= w
        else:
            states[labels[i][0]][3] += w
            states[max_idx][1] += w
            for j in xrange(class_num):
                states[j][2] += w
            states[labels[i][0]][2] -= w
            states[max_idx][2] -= w
    return states


def compute_metrics(states):
    class_num = states.shape[0]
    total_tp_count = 0.0
    total_fp_count = 0.0
    total_fn_count = 0.0
    macro_avg_precision = 0.0
    macro_avg_recall = 0.0
    for i in xrange(class_num):
        total_tp_count += states[i][0]
        total_fp_count += states[i][1]
        total_fn_count += states[i][3]
        macro_avg_precision += calc_precision(states[i][0], states[i][1])
        macro_avg_recall += calc_recall(states[i][0], states[i][3])
    metrics = []
    macro_avg_precision /= class_num
    macro_avg_recall /= class_num
    metrics.append(macro_avg_precision)
    metrics.append(macro_avg_recall)
    metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall))
    micro_avg_precision = calc_precision(total_tp_count, total_fp_count)
    metrics.append(micro_avg_precision)
    micro_avg_recall = calc_recall(total_tp_count, total_fn_count)
    metrics.append(micro_avg_recall)
    metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall))
    return np.array(metrics).astype('float32')


class TestPrecisionRecallOp_0(OpTest):
    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        class_num = 10
        predictions = np.random.uniform(0, 1.0,
                                        (ins_num, class_num)).astype('float32')
        labels = np.random.choice(xrange(class_num), ins_num).reshape(
            (ins_num, 1)).astype('int32')
        states = get_states(predictions, labels)
        metrics = compute_metrics(states)
        self.inputs = {'Predictions': predictions, 'Labels': labels}
        self.outputs = {
            'BatchMetrics': metrics,
            'AccumMetrics': metrics,
            'AccumStatesInfo': states
        }

    def test_check_output(self):
        self.check_output()


class TestPrecisionRecallOp_1(OpTest):
    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        class_num = 10
        predictions = np.random.uniform(0, 1.0,
                                        (ins_num, class_num)).astype('float32')
        weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        predictions = np.random.random((ins_num, class_num)).astype('float32')
        labels = np.random.choice(xrange(class_num), ins_num).reshape(
            (ins_num, 1)).astype('int32')
        states = get_states(predictions, labels, weights)
        metrics = compute_metrics(states)
        self.inputs = {
            'Predictions': predictions,
            'Labels': labels,
            'Weights': weights
        }
        self.outputs = {
            'BatchMetrics': metrics,
            'AccumMetrics': metrics,
            'AccumStatesInfo': states
        }

    def test_check_output(self):
        self.check_output()


class TestPrecisionRecallOp_2(OpTest):
    def setUp(self):
        self.op_type = "precision_recall"
        ins_num = 64
        class_num = 10
        predictions = np.random.uniform(0, 1.0,
                                        (ins_num, class_num)).astype('float32')
        weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32')
        predictions = np.random.random((ins_num, class_num)).astype('float32')
        labels = np.random.choice(xrange(class_num), ins_num).reshape(
            (ins_num, 1)).astype('int32')
        states = np.random.randint(0, 30, (class_num, 4)).astype('float32')
        accum_states = get_states(predictions, labels, weights)
        batch_metrics = compute_metrics(accum_states)
        accum_states += states
        accum_metrics = compute_metrics(accum_states)
        self.inputs = {
            'Predictions': predictions,
            'Labels': labels,
            'Weights': weights,
            'StatesInfo': states
        }
        self.outputs = {
            'BatchMetrics': batch_metrics,
            'AccumMetrics': accum_metrics,
            'AccumStatesInfo': accum_states
        }

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()
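A note on how the expected states are assembled: get_states above (mirroring the kernel loop in precision_recall_op.h) applies a one-vs-rest rule per sample. Every class is first credited with the sample weight as a true negative; on a correct top-1 prediction the predicted class turns that credit into a true positive, while on a miss the predicted class takes a false positive, the labelled class takes a false negative, and both give the true-negative credit back. A small self-contained sketch with invented numbers (three classes, two samples, the second weighted 0.5):

# Standalone illustration of the one-vs-rest counting rule used by
# get_states and by the C++ kernel. All numbers here are made up.
import numpy as np

class_num = 3
# columns: [TP, FP, TN, FN]
states = np.zeros((class_num, 4))

samples = [
    # (predicted scores, label, weight)
    (np.array([0.1, 0.7, 0.2]), 1, 1.0),   # hit: argmax == label
    (np.array([0.6, 0.3, 0.1]), 2, 0.5),   # miss: predicts 0, label is 2
]

for scores, label, w in samples:
    pred = int(np.argmax(scores))
    states[:, 2] += w            # start by crediting every class with TN
    if pred == label:
        states[pred][0] += w     # TP for the predicted class
        states[pred][2] -= w     # ... which therefore is not a TN
    else:
        states[pred][1] += w     # FP for the predicted class
        states[label][3] += w    # FN for the true class
        states[pred][2] -= w     # neither of them keeps the TN credit
        states[label][2] -= w

print(states)
# Expected:
# class 0: FP 0.5, TN 1.0          -> [0. , 0.5, 1. , 0. ]
# class 1: TP 1.0, TN 0.5          -> [1. , 0. , 0.5, 0. ]
# class 2: TN 1.0, FN 0.5          -> [0. , 0. , 1. , 0.5]

Feeding such a matrix to compute_metrics gives the BatchMetrics that the three test cases check; TestPrecisionRecallOp_2 additionally adds a random initial StatesInfo before computing AccumMetrics.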