diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index 001d26936886f12efc6eaa0333bb12e4e7118d67..e0eebad08bb6b9a15d9c0f356215404884bee0e9 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -28,8 +28,9 @@ class AucOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("Label"), "Input of Label should not be null."); auto predict_width = ctx->GetInputDim("Predict")[1]; - PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_width, 2, - "Only support binary classification"); + PADDLE_INFERSHAPE_ENFORCE_LE(ctx, predict_width, 2, + "Only support binary classification," + "prediction dims[1] should be 1 or 2"); auto predict_height = ctx->GetInputDim("Predict")[0]; auto label_height = ctx->GetInputDim("Label")[0]; diff --git a/paddle/fluid/operators/metrics/auc_op.h b/paddle/fluid/operators/metrics/auc_op.h index 4ab5cfe53c67eeaa995d7e955eec63a065c5eec5..6fb4749b35a37dfbb18d322920b2744d7a0882d4 100644 --- a/paddle/fluid/operators/metrics/auc_op.h +++ b/paddle/fluid/operators/metrics/auc_op.h @@ -75,7 +75,10 @@ class AucKernel : public framework::OpKernel { const auto *label_data = label->data(); for (size_t i = 0; i < batch_size; i++) { - auto predict_data = inference_data[i * inference_width + 1]; + // if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob + // if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob + auto predict_data = + inference_data[i * inference_width + (inference_width - 1)]; PADDLE_ENFORCE_LE(predict_data, 1, "The predict data must less or equal 1."); PADDLE_ENFORCE_GE(predict_data, 0, diff --git a/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py b/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3e93fa57b081fa1ce0ec6309ee166335b05ec9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py @@ -0,0 +1,64 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import metrics + + +class TestAucSinglePredOp(OpTest): + def setUp(self): + self.op_type = "auc" + pred = np.random.random((128, 2)).astype("float32") + pred0 = pred[:, 0].reshape(128, 1) + labels = np.random.randint(0, 2, (128, 1)).astype("int64") + num_thresholds = 200 + + stat_pos = np.zeros((num_thresholds + 1, )).astype("int64") + stat_neg = np.zeros((num_thresholds + 1, )).astype("int64") + + self.inputs = { + 'Predict': pred0, + 'Label': labels, + "StatPos": stat_pos, + "StatNeg": stat_neg + } + self.attrs = { + 'curve': 'ROC', + 'num_thresholds': num_thresholds, + "slide_steps": 1 + } + + python_auc = metrics.Auc(name="auc", + curve='ROC', + num_thresholds=num_thresholds) + for i in range(128): + pred[i][1] = pred[i][0] + python_auc.update(pred, labels) + + self.outputs = { + 'AUC': np.array(python_auc.eval()), + 'StatPosOut': np.array(python_auc._stat_pos), + 'StatNegOut': np.array(python_auc._stat_neg) + } + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()