From 65dbbd57af4016953338b27e80aa05cfed62c220 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Thu, 26 Oct 2017 22:42:44 +0800 Subject: [PATCH] Add and pass unittests. --- paddle/operators/precision_recall_op.cc | 21 ++- paddle/operators/precision_recall_op.h | 14 +- .../tests/test_precision_recall_op.py | 164 ++++++++++++++++++ 3 files changed, 188 insertions(+), 11 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_precision_recall_op.py diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc index 22eaa3f36e..47a16b9461 100644 --- a/paddle/operators/precision_recall_op.cc +++ b/paddle/operators/precision_recall_op.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/operators/precision_recall_op.h" + namespace paddle { namespace operators { @@ -37,13 +39,15 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { if (ctx->HasInput("Weights")) { auto weights_dims = ctx->GetInputDim("Weights"); - PADDLE_ENFORCE_EQ(weights_dims, {predictions_dims[0], 1}, + PADDLE_ENFORCE_EQ(weights_dims, + framework::make_ddim({predictions_dims[0], 1}), "The shape of Input(Weights) should be " "[batch_size, 1]."); } if (ctx->HasInput("StatesInfo")) { auto states_dims = ctx->GetInputDim("StatesInfo"); - PADDLE_ENFORCE_EQ(states_dims, {predictions_dims[1], 4}, + PADDLE_ENFORCE_EQ(states_dims, + framework::make_ddim({predictions_dims[1], 4}), "The shape of Input(StatesInfo) should be " "[class_number, 4]."); } @@ -71,6 +75,12 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { // [ TP, FP, TN, FN ] ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4}); } + + protected: + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Predictions")->type()); + } }; class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { @@ -98,6 +108,9 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { "provided, current state will be accumulated to this state and " "the accumulation state will be as the output state.") .AsDispensable(); + AddOutput("BatchMetrics", ""); + AddOutput("AccumMetrics", ""); + AddOutput("AccumStatesInfo", ""); AddComment(R"DOC( )DOC"); @@ -113,6 +126,4 @@ REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp, REGISTER_OP_CPU_KERNEL( precision_recall, ops::PrecisionRecallKernel, - ops::PrecisionRecallKernel, - ops::PrecisionRecallKernel, - ops::PrecisionRecallKernel, + ops::PrecisionRecallKernel); diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h index 7ed5f2387e..3bc638ea44 100644 --- a/paddle/operators/precision_recall_op.h +++ b/paddle/operators/precision_recall_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { @@ -37,7 +39,7 @@ class PrecisionRecallKernel : public framework::OpKernel { auto* out2 = ctx.Output("AccumStatesInfo"); const T* predictions_data = in0->data(); - const T* labels_data = in1->data(); + const int* labels_data = in1->data(); const T* weights_data = in2 ? in2->data() : nullptr; const T* states_data = in3 ? in3->data() : nullptr; T* batch_metrics_data = out0->mutable_data(ctx.GetPlace()); @@ -45,7 +47,7 @@ class PrecisionRecallKernel : public framework::OpKernel { out2->mutable_data(ctx.GetPlace()); auto accum_states = EigenMatrix::From(*out2); accum_states.setZero(); - T* accum_states_data = out2->data(ctx.GetPlace()); + T* accum_states_data = out2->data(); size_t sample_num = in0->dims()[0]; size_t class_dim = in0->dims()[1]; @@ -76,7 +78,7 @@ class PrecisionRecallKernel : public framework::OpKernel { accum_states_data[j * state_var_num + TN] += w; } accum_states_data[max_idx * state_var_num + TN] -= w; - accum_states_data[labels_data[j] * state_var_num + TN] -= w; + accum_states_data[labels_data[i] * state_var_num + TN] -= w; } } @@ -108,7 +110,7 @@ class PrecisionRecallKernel : public framework::OpKernel { if (tp_count > 0.0 || fn_count > 0.0) { return tp_count / (tp_count + fn_count); } - return 1.0 + return 1.0; } static inline T CalcF1Score(T precision, T recall) { @@ -120,7 +122,7 @@ class PrecisionRecallKernel : public framework::OpKernel { protected: void ComputeMetrics(const T* states_data, T* metrics_data, - size_t state_var_num, size_t class_dim) { + size_t state_var_num, size_t class_dim) const { T total_tp_count = 0; T total_fp_count = 0; T total_fn_count = 0; @@ -143,7 +145,7 @@ class PrecisionRecallKernel : public framework::OpKernel { T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count); T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count); - T micro_f1_score = CalcRecall(micro_avg_precision, micro_avg_recall); + T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall); // fill metrics data metrics_data[0] = macro_avg_precision; diff --git a/python/paddle/v2/framework/tests/test_precision_recall_op.py b/python/paddle/v2/framework/tests/test_precision_recall_op.py new file mode 100644 index 0000000000..33efd717d1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_precision_recall_op.py @@ -0,0 +1,164 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def calc_precision(tp_count, fp_count): + if tp_count > 0.0 or fp_count > 0.0: + return tp_count / (tp_count + fp_count) + return 1.0 + + +def calc_recall(tp_count, fn_count): + if tp_count > 0.0 or fn_count > 0.0: + return tp_count / (tp_count + fn_count) + return 1.0 + + +def calc_f1_score(precision, recall): + if precision > 0.0 or recall > 0.0: + return 2 * precision * recall / (precision + recall) + return 0.0 + + +def get_states(predictions, labels, weights=None): + ins_num = predictions.shape[0] + class_num = predictions.shape[1] + # TP FP TN FN + states = np.zeros((class_num, 4)).astype('float32') + for i in xrange(ins_num): + w = weights[i] if weights is not None else 1.0 + max_idx = np.argmax(predictions[i]) + if max_idx == labels[i][0]: + states[max_idx][0] += w + for j in xrange(class_num): + states[j][2] += w + states[max_idx][2] -= w + else: + states[labels[i][0]][3] += w + states[max_idx][1] += w + for j in xrange(class_num): + states[j][2] += w + states[labels[i][0]][2] -= w + states[max_idx][2] -= w + return states + + +def compute_metrics(states): + class_num = states.shape[0] + total_tp_count = 0.0 + total_fp_count = 0.0 + total_fn_count = 0.0 + macro_avg_precision = 0.0 + macro_avg_recall = 0.0 + for i in xrange(class_num): + total_tp_count += states[i][0] + total_fp_count += states[i][1] + total_fn_count += states[i][3] + macro_avg_precision += calc_precision(states[i][0], states[i][1]) + macro_avg_recall += calc_recall(states[i][0], states[i][3]) + metrics = [] + macro_avg_precision /= class_num + macro_avg_recall /= class_num + metrics.append(macro_avg_precision) + metrics.append(macro_avg_recall) + metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall)) + micro_avg_precision = calc_precision(total_tp_count, total_fp_count) + metrics.append(micro_avg_precision) + micro_avg_recall = calc_recall(total_tp_count, total_fn_count) + metrics.append(micro_avg_recall) + metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall)) + return np.array(metrics).astype('float32') + + +class TestPrecisionRecallOp_0(OpTest): + def setUp(self): + self.op_type = "precision_recall" + ins_num = 64 + class_num = 10 + predictions = np.random.uniform(0, 1.0, + (ins_num, class_num)).astype('float32') + labels = np.random.choice(xrange(class_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + states = get_states(predictions, labels) + metrics = compute_metrics(states) + + self.inputs = {'Predictions': predictions, 'Labels': labels} + + self.outputs = { + 'BatchMetrics': metrics, + 'AccumMetrics': metrics, + 'AccumStatesInfo': states + } + + def test_check_output(self): + self.check_output() + + +class TestPrecisionRecallOp_1(OpTest): + def setUp(self): + self.op_type = "precision_recall" + ins_num = 64 + class_num = 10 + predictions = np.random.uniform(0, 1.0, + (ins_num, class_num)).astype('float32') + weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + predictions = np.random.random((ins_num, class_num)).astype('float32') + labels = np.random.choice(xrange(class_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + + states = get_states(predictions, labels, weights) + metrics = compute_metrics(states) + self.inputs = { + 'Predictions': predictions, + 'Labels': labels, + 'Weights': weights + } + + self.outputs = { + 'BatchMetrics': metrics, + 'AccumMetrics': metrics, + 'AccumStatesInfo': states + } + + def test_check_output(self): + self.check_output() + + +class TestPrecisionRecallOp_2(OpTest): + def setUp(self): + self.op_type = "precision_recall" + ins_num = 64 + class_num = 10 + predictions = np.random.uniform(0, 1.0, + (ins_num, class_num)).astype('float32') + weights = np.random.uniform(0, 1.0, (ins_num, 1)).astype('float32') + predictions = np.random.random((ins_num, class_num)).astype('float32') + labels = np.random.choice(xrange(class_num), ins_num).reshape( + (ins_num, 1)).astype('int32') + states = np.random.randint(0, 30, (class_num, 4)).astype('float32') + + accum_states = get_states(predictions, labels, weights) + batch_metrics = compute_metrics(accum_states) + accum_states += states + accum_metrics = compute_metrics(accum_states) + + self.inputs = { + 'Predictions': predictions, + 'Labels': labels, + 'Weights': weights, + 'StatesInfo': states + } + + self.outputs = { + 'BatchMetrics': batch_metrics, + 'AccumMetrics': accum_metrics, + 'AccumStatesInfo': accum_states + } + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() -- GitLab