diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index 9a4a30b9cbd0126b0c2b4a334e1548778d91c013..d9bb2982f0f64a89641dfa0bbc40242348308b41 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -49,12 +49,15 @@ class AucOp : public framework::OperatorWithKernel { ctx->SetOutputDim("AUC", {1}); - // slide_steps = slide_steps == 0 ? 1 : slide_steps; - int need_batch_id = slide_steps ? 1 : 0; - ctx->SetOutputDim("StatPosOut", - {(1 + slide_steps) * num_pred_buckets + need_batch_id}); - ctx->SetOutputDim("StatNegOut", - {(1 + slide_steps) * num_pred_buckets + need_batch_id}); + if (slide_steps) { + ctx->SetOutputDim("StatPosOut", + {(1 + slide_steps) * num_pred_buckets + 1}); + ctx->SetOutputDim("StatNegOut", + {(1 + slide_steps) * num_pred_buckets + 1}); + } else { + ctx->SetOutputDim("StatPosOut", {1, num_pred_buckets}); + ctx->SetOutputDim("StatNegOut", {1, num_pred_buckets}); + } } protected: diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py index 3517d3ed824d201117550d011c6a24e480c8a773..4ffe65023063aeecce09a5908ff51522cdeca8e8 100644 --- a/python/paddle/fluid/layers/metric_op.py +++ b/python/paddle/fluid/layers/metric_op.py @@ -184,9 +184,9 @@ def auc(input, # for global auc # Needn't maintain the batch id stat_pos = helper.create_global_variable( - persistable=True, dtype='int64', shape=[num_thresholds + 1]) + persistable=True, dtype='int64', shape=[1, num_thresholds + 1]) stat_neg = helper.create_global_variable( - persistable=True, dtype='int64', shape=[num_thresholds + 1]) + persistable=True, dtype='int64', shape=[1, num_thresholds + 1]) for var in [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]: helper.set_variable_initializer( diff --git a/python/paddle/fluid/tests/unittests/test_auc_op.py b/python/paddle/fluid/tests/unittests/test_auc_op.py index 4835c38f5fcb3f33526e6a6c60149eaf22a6726e..a07587fdb2818f4cd34cd716c0be59b6e97df927 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_op.py @@ -64,5 +64,45 @@ class TestAucOp(OpTest): self.check_output() +class TestGlobalAucOp(OpTest): + def setUp(self): + self.op_type = "auc" + pred = np.random.random((128, 2)).astype("float32") + labels = np.random.randint(0, 2, (128, 1)).astype("int64") + num_thresholds = 200 + slide_steps = 0 + + stat_pos = np.zeros((1, (num_thresholds + 1))).astype("int64") + stat_neg = np.zeros((1, (num_thresholds + 1))).astype("int64") + + self.inputs = { + 'Predict': pred, + 'Label': labels, + "StatPos": stat_pos, + "StatNeg": stat_neg + } + self.attrs = { + 'curve': 'ROC', + 'num_thresholds': num_thresholds, + "slide_steps": slide_steps + } + + python_auc = metrics.Auc(name="auc", + curve='ROC', + num_thresholds=num_thresholds) + python_auc.update(pred, labels) + + pos = python_auc._stat_pos + neg = python_auc._stat_neg + self.outputs = { + 'AUC': np.array(python_auc.eval()), + 'StatPosOut': np.array(pos), + 'StatNegOut': np.array(neg) + } + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py b/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py index 574a820decae7f7e2b47ca7540e9f2484eae781b..5093dc1f990a9026d9953f3d97bed07b6b1c2cb7 100644 --- a/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py +++ b/python/paddle/fluid/tests/unittests/test_auc_single_pred_op.py @@ -67,5 +67,48 @@ class TestAucSinglePredOp(OpTest): self.check_output() +class TestAucGlobalSinglePredOp(OpTest): + def setUp(self): + self.op_type = "auc" + pred = np.random.random((128, 2)).astype("float32") + pred0 = pred[:, 0].reshape(128, 1) + labels = np.random.randint(0, 2, (128, 1)).astype("int64") + num_thresholds = 200 + slide_steps = 0 + + stat_pos = np.zeros((1, (num_thresholds + 1))).astype("int64") + stat_neg = np.zeros((1, (num_thresholds + 1))).astype("int64") + + self.inputs = { + 'Predict': pred0, + 'Label': labels, + "StatPos": stat_pos, + "StatNeg": stat_neg + } + self.attrs = { + 'curve': 'ROC', + 'num_thresholds': num_thresholds, + "slide_steps": slide_steps + } + + python_auc = metrics.Auc(name="auc", + curve='ROC', + num_thresholds=num_thresholds) + for i in range(128): + pred[i][1] = pred[i][0] + python_auc.update(pred, labels) + + pos = python_auc._stat_pos + neg = python_auc._stat_neg + self.outputs = { + 'AUC': np.array(python_auc.eval()), + 'StatPosOut': np.array(pos), + 'StatNegOut': np.array(neg) + } + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main()