未验证 提交 ab57d389 编写于 作者: G guru4elephant 提交者: GitHub

make auc op compatible with 1 dim (#18551)

* make auc op compatible with 1 dim
上级 b71b4543
......@@ -28,8 +28,9 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label should not be null.");
auto predict_width = ctx->GetInputDim("Predict")[1];
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_width, 2,
"Only support binary classification");
PADDLE_INFERSHAPE_ENFORCE_LE(ctx, predict_width, 2,
"Only support binary classification,"
"prediction dims[1] should be 1 or 2");
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0];
......
......@@ -75,7 +75,10 @@ class AucKernel : public framework::OpKernel<T> {
const auto *label_data = label->data<int64_t>();
for (size_t i = 0; i < batch_size; i++) {
auto predict_data = inference_data[i * inference_width + 1];
// if predict_data[i] has dim of 2, then predict_data[i][1] is pos prob
// if predict_data[i] has dim of 1, then predict_data[i][0] is pos prob
auto predict_data =
inference_data[i * inference_width + (inference_width - 1)];
PADDLE_ENFORCE_LE(predict_data, 1,
"The predict data must less or equal 1.");
PADDLE_ENFORCE_GE(predict_data, 0,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import metrics
class TestAucSinglePredOp(OpTest):
def setUp(self):
self.op_type = "auc"
pred = np.random.random((128, 2)).astype("float32")
pred0 = pred[:, 0].reshape(128, 1)
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
num_thresholds = 200
stat_pos = np.zeros((num_thresholds + 1, )).astype("int64")
stat_neg = np.zeros((num_thresholds + 1, )).astype("int64")
self.inputs = {
'Predict': pred0,
'Label': labels,
"StatPos": stat_pos,
"StatNeg": stat_neg
}
self.attrs = {
'curve': 'ROC',
'num_thresholds': num_thresholds,
"slide_steps": 1
}
python_auc = metrics.Auc(name="auc",
curve='ROC',
num_thresholds=num_thresholds)
for i in range(128):
pred[i][1] = pred[i][0]
python_auc.update(pred, labels)
self.outputs = {
'AUC': np.array(python_auc.eval()),
'StatPosOut': np.array(python_auc._stat_pos),
'StatNegOut': np.array(python_auc._stat_neg)
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册