From 06c7c8c80e2c843afb7c5b156766533a5a389be9 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Thu, 26 Oct 2017 11:59:54 +0800
Subject: [PATCH] Add CPU kernel.

---
 paddle/operators/precision_recall_op.cc | 118 ++++++++++++++++++
 paddle/operators/precision_recall_op.h  | 159 ++++++++++++++++++++++++
 2 files changed, 277 insertions(+)
 create mode 100644 paddle/operators/precision_recall_op.cc
 create mode 100644 paddle/operators/precision_recall_op.h

diff --git a/paddle/operators/precision_recall_op.cc b/paddle/operators/precision_recall_op.cc
new file mode 100644
index 000000000..22eaa3f36
--- /dev/null
+++ b/paddle/operators/precision_recall_op.cc
@@ -0,0 +1,118 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/precision_recall_op.h"
+
+namespace paddle {
+namespace operators {
+
+class PrecisionRecallOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    // may contain Weights and StatesInfo
+    PADDLE_ENFORCE(ctx->HasInput("Predictions"),
+                   "Input(Predictions) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Labels"),
+                   "Input(Labels) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchMetrics"),
+                   "Output(BatchMetrics) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("AccumMetrics"),
+                   "Output(AccumMetrics) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("AccumStatesInfo"),
+                   "Output(AccumStatesInfo) should not be null.");
+
+    auto predictions_dims = ctx->GetInputDim("Predictions");
+    auto labels_dims = ctx->GetInputDim("Labels");
+
+    if (ctx->HasInput("Weights")) {
+      auto weights_dims = ctx->GetInputDim("Weights");
+      PADDLE_ENFORCE_EQ(weights_dims,
+                        framework::make_ddim({predictions_dims[0], 1}),
+                        "The shape of Input(Weights) should be "
+                        "[batch_size, 1].");
+    }
+    if (ctx->HasInput("StatesInfo")) {
+      auto states_dims = ctx->GetInputDim("StatesInfo");
+      PADDLE_ENFORCE_EQ(states_dims,
+                        framework::make_ddim({predictions_dims[1], 4}),
+                        "The shape of Input(StatesInfo) should be "
+                        "[class_number, 4].");
+    }
+    PADDLE_ENFORCE_EQ(predictions_dims[0], labels_dims[0],
+                      "The 1st dimension of Input(Predictions) and "
+                      "Input(Labels) both are batch_size and should be "
+                      "the same.");
+    PADDLE_ENFORCE_EQ(labels_dims[1], 1,
+                      "The 2nd dimension of Input(Labels) contains the "
+                      "instance label and should be equal to 1.");
+    PADDLE_ENFORCE_GE(predictions_dims[1], 1,
+                      "The 2nd dimension of Input(Predictions) is equal to "
+                      "the class number and should be at least 1.");
+
+    // Layouts of BatchMetrics and AccumMetrics both are:
+    // [
+    //  macro average precision, macro average recall, macro average F1 score,
+    //  micro average precision, micro average recall, micro average F1 score
+    // ]
+    ctx->SetOutputDim("BatchMetrics", {6});
+    ctx->SetOutputDim("AccumMetrics", {6});
+    // Shape of AccumStatesInfo is [class_number, 4]
+    // The layout of each row is:
+    // [ TP, FP, TN, FN ]
+    ctx->SetOutputDim("AccumStatesInfo", {predictions_dims[1], 4});
+  }
+};
+
+class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  PrecisionRecallOpMaker(framework::OpProto *proto,
+                         framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Predictions",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
+             "where N is the batch size and D is the number of classes. "
+             "Each row contains probabilities for an instance computed by "
+             "the previous operator.");
+    AddInput("Labels",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
+             "where N is the batch size. Each element is a label and the "
+             "value should be in [0, class_number - 1].");
+    AddInput("Weights",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x 1, "
+             "where N is the batch size. This input is optional. If provided, "
+             "the weight of each instance is considered when computing "
+             "metrics.")
+        .AsDispensable();
+    AddInput("StatesInfo",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
+             "where D is the number of classes. This input is optional. If "
+             "provided, the current state will be accumulated to it and the "
+             "accumulated state will be the output state.")
+        .AsDispensable();
+    AddOutput("BatchMetrics",
+              "(Tensor, default Tensor<float>), a 1-D tensor with shape {6}, "
+              "containing macro average precision, macro average recall, "
+              "macro average F1 score, micro average precision, micro "
+              "average recall and micro average F1 score for the current "
+              "batch.");
+    AddOutput("AccumMetrics",
+              "(Tensor, default Tensor<float>), a 1-D tensor with shape {6}, "
+              "containing the same six metrics computed from the accumulated "
+              "states.");
+    AddOutput("AccumStatesInfo",
+              "(Tensor, default Tensor<float>), a 2-D tensor with shape "
+              "D x 4, where D is the number of classes. Each row is the "
+              "accumulated [TP, FP, TN, FN] for one class.");
+
+    AddComment(R"DOC(
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
+                             ops::PrecisionRecallOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    precision_recall,
+    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
+    ops::PrecisionRecallKernel<paddle::platform::CPUPlace, double>);

diff --git a/paddle/operators/precision_recall_op.h b/paddle/operators/precision_recall_op.h
new file mode 100644
index 000000000..7ed5f2387
--- /dev/null
+++ b/paddle/operators/precision_recall_op.h
@@ -0,0 +1,159 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+enum StateVariable { TP = 0, FP, TN, FN };
+
+template <typename Place, typename T>
+class PrecisionRecallKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in0 = ctx.Input<Tensor>("Predictions");
+    auto* in1 = ctx.Input<Tensor>("Labels");
+    auto* in2 = ctx.Input<Tensor>("Weights");
+    auto* in3 = ctx.Input<Tensor>("StatesInfo");
+    auto* out0 = ctx.Output<Tensor>("BatchMetrics");
+    auto* out1 = ctx.Output<Tensor>("AccumMetrics");
+    auto* out2 = ctx.Output<Tensor>("AccumStatesInfo");
+
+    const T* predictions_data = in0->data<T>();
+    const T* labels_data = in1->data<T>();
+    const T* weights_data = in2 ? in2->data<T>() : nullptr;
+    const T* states_data = in3 ? in3->data<T>() : nullptr;
+    T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
+    T* accum_metrics_data = out1->mutable_data<T>(ctx.GetPlace());
+    out2->mutable_data<T>(ctx.GetPlace());
+    auto accum_states = EigenMatrix<T>::From(*out2);
+    accum_states.setZero();
+    T* accum_states_data = out2->data<T>();
+
+    size_t sample_num = in0->dims()[0];
+    size_t class_dim = in0->dims()[1];
+    size_t state_var_num = 4;  // TP FP TN FN
+
+    // get states info for current batch
+    for (size_t i = 0; i < sample_num; ++i) {
+      size_t max_idx = 0;
+      T max_val = predictions_data[i * class_dim];
+      for (size_t j = 1; j < class_dim; ++j) {
+        if (max_val < predictions_data[i * class_dim + j]) {
+          max_idx = j;
+          max_val = predictions_data[i * class_dim + j];
+        }
+      }
+
+      T w = weights_data ? weights_data[i] : 1.0;
+      size_t label = static_cast<size_t>(labels_data[i]);
+      if (max_idx == label) {
+        accum_states_data[max_idx * state_var_num + TP] += w;
+        for (size_t j = 0; j < class_dim; ++j) {
+          accum_states_data[j * state_var_num + TN] += w;
+        }
+        accum_states_data[max_idx * state_var_num + TN] -= w;
+      } else {
+        accum_states_data[label * state_var_num + FN] += w;
+        accum_states_data[max_idx * state_var_num + FP] += w;
+        for (size_t j = 0; j < class_dim; ++j) {
+          accum_states_data[j * state_var_num + TN] += w;
+        }
+        accum_states_data[max_idx * state_var_num + TN] -= w;
+        accum_states_data[label * state_var_num + TN] -= w;
+      }
+    }
+
+    ComputeMetrics(accum_states_data, batch_metrics_data, state_var_num,
+                   class_dim);
+
+    if (states_data) {
+      for (size_t i = 0; i < class_dim; ++i) {
+        for (size_t j = 0; j < state_var_num; ++j) {
+          size_t idx = i * state_var_num + j;
+          accum_states_data[idx] += states_data[idx];
+        }
+      }
+    }
+
+    ComputeMetrics(accum_states_data, accum_metrics_data, state_var_num,
+                   class_dim);
+  }
+
+  // exposed for reuse
+  static inline T CalcPrecision(T tp_count, T fp_count) {
+    if (tp_count > 0.0 || fp_count > 0.0) {
+      return tp_count / (tp_count + fp_count);
+    }
+    return 1.0;
+  }
+
+  static inline T CalcRecall(T tp_count, T fn_count) {
+    if (tp_count > 0.0 || fn_count > 0.0) {
+      return tp_count / (tp_count + fn_count);
+    }
+    return 1.0;
+  }
+
+  static inline T CalcF1Score(T precision, T recall) {
+    if (precision > 0.0 || recall > 0.0) {
+      return 2 * precision * recall / (precision + recall);
+    }
+    return 0.0;
+  }
+
+ protected:
+  void ComputeMetrics(const T* states_data, T* metrics_data,
+                      size_t state_var_num, size_t class_dim) const {
+    T total_tp_count = 0;
+    T total_fp_count = 0;
+    T total_fn_count = 0;
+    T macro_avg_precision = 0.0;
+    T macro_avg_recall = 0.0;
+
+    for (size_t i = 0; i < class_dim; ++i) {
+      T tp_count = states_data[i * state_var_num + TP];
+      T fp_count = states_data[i * state_var_num + FP];
+      T fn_count = states_data[i * state_var_num + FN];
+      total_tp_count += tp_count;
+      total_fp_count += fp_count;
+      total_fn_count += fn_count;
+      macro_avg_precision += CalcPrecision(tp_count, fp_count);
+      macro_avg_recall += CalcRecall(tp_count, fn_count);
+    }
+    macro_avg_precision /= class_dim;
+    macro_avg_recall /= class_dim;
+    T macro_f1_score = CalcF1Score(macro_avg_precision, macro_avg_recall);
+
+    T micro_avg_precision = CalcPrecision(total_tp_count, total_fp_count);
+    T micro_avg_recall = CalcRecall(total_tp_count, total_fn_count);
+    T micro_f1_score = CalcF1Score(micro_avg_precision, micro_avg_recall);
+
+    // fill metrics data
+    metrics_data[0] = macro_avg_precision;
+    metrics_data[1] = macro_avg_recall;
+    metrics_data[2] = macro_f1_score;
+    metrics_data[3] = micro_avg_precision;
+    metrics_data[4] = micro_avg_recall;
+    metrics_data[5] = micro_f1_score;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
--
GitLab
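
For reference, the macro/micro averaging that ComputeMetrics performs can be exercised outside the framework. The sketch below is illustrative only and not part of the patch; it assumes the same per-class state layout [TP, FP, TN, FN] and the same conventions as the kernel (precision and recall default to 1 when a class has no relevant counts, F1 defaults to 0). The helper names (ComputeMetricsDemo, kStateVarNum) are made up for this example.

#include <cstddef>
#include <iostream>
#include <vector>

static const size_t kStateVarNum = 4;  // per-class layout: [TP, FP, TN, FN]

double Precision(double tp, double fp) {
  return (tp > 0.0 || fp > 0.0) ? tp / (tp + fp) : 1.0;
}

double Recall(double tp, double fn) {
  return (tp > 0.0 || fn > 0.0) ? tp / (tp + fn) : 1.0;
}

double F1(double p, double r) {
  return (p > 0.0 || r > 0.0) ? 2.0 * p * r / (p + r) : 0.0;
}

// states holds class_num rows of [TP, FP, TN, FN]; the result is
// {macro precision, macro recall, macro F1, micro precision, micro recall,
//  micro F1}, matching the BatchMetrics/AccumMetrics layout.
std::vector<double> ComputeMetricsDemo(const std::vector<double>& states,
                                       size_t class_num) {
  double total_tp = 0.0, total_fp = 0.0, total_fn = 0.0;
  double macro_p = 0.0, macro_r = 0.0;
  for (size_t i = 0; i < class_num; ++i) {
    double tp = states[i * kStateVarNum + 0];
    double fp = states[i * kStateVarNum + 1];
    double fn = states[i * kStateVarNum + 3];
    total_tp += tp;
    total_fp += fp;
    total_fn += fn;
    macro_p += Precision(tp, fp);  // macro: average the per-class metrics
    macro_r += Recall(tp, fn);
  }
  macro_p /= class_num;
  macro_r /= class_num;
  double micro_p = Precision(total_tp, total_fp);  // micro: pool counts first
  double micro_r = Recall(total_tp, total_fn);
  return {macro_p, macro_r, F1(macro_p, macro_r),
          micro_p, micro_r, F1(micro_p, micro_r)};
}

int main() {
  // Two classes, four unweighted instances, one misclassified:
  // class 0: TP=2, FP=1, TN=1, FN=0; class 1: TP=1, FP=0, TN=2, FN=1.
  std::vector<double> states = {2, 1, 1, 0, 1, 0, 2, 1};
  for (double m : ComputeMetricsDemo(states, 2)) std::cout << m << " ";
  std::cout << "\n";
  return 0;
}

With these sample counts, micro precision and micro recall both come out to 3/4, since micro averaging pools TP/FP/FN across classes before dividing, and one of the four instances is misclassified.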