#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.fluid import metrics

from eager_op_test import OpTest


class TestAucOp(OpTest):
    """Check the ``auc`` operator with a sliding window (``slide_steps=1``).

    Expected outputs come from ``fluid.metrics.Auc`` fed with the same
    random predictions and labels that are passed to the operator.
    """

    def setUp(self):
        self.op_type = "auc"
        # Two-column predictions; NOTE(review): the positive-class score is
        # presumably column 1 -- confirm against metrics.Auc.update.
        predictions = np.random.random((128, 2)).astype("float32")
        targets = np.random.randint(0, 2, (128, 1)).astype("int64")
        thresholds = 200
        steps = 1

        # Flat, zero-initialised int64 stat buffers of length
        # (1 + slide_steps) * (num_thresholds + 1) + 1.
        stat_size = (1 + steps) * (thresholds + 1) + 1
        self.inputs = {
            'Predict': predictions,
            'Label': targets,
            "StatPos": np.zeros(stat_size, dtype="int64"),
            "StatNeg": np.zeros(stat_size, dtype="int64"),
        }
        self.attrs = {
            'curve': 'ROC',
            'num_thresholds': thresholds,
            "slide_steps": steps,
        }

        reference = metrics.Auc(
            name="auc", curve='ROC', num_thresholds=thresholds
        )
        reference.update(predictions, targets)

        # The expected buffers hold the reference histogram twice followed
        # by a trailing 1. This layout is only valid for slide_steps == 1.
        # NOTE(review): presumably one per-step slot, one window-sum slot and
        # a step counter -- confirm against the auc kernel before changing
        # slide_steps.
        expected_pos = reference._stat_pos * 2 + [1]
        expected_neg = reference._stat_neg * 2 + [1]

        self.outputs = {
            'AUC': np.array(reference.eval()),
            'StatPosOut': np.array(expected_pos),
            'StatNegOut': np.array(expected_neg),
        }

    def test_check_output(self):
        # Only the static-graph path is checked (check_dygraph=False).
        self.check_output(check_dygraph=False)


class TestGlobalAucOp(OpTest):
    """Check the ``auc`` operator in global mode (``slide_steps=0``).

    The stat buffers are a single zero-initialised ``(1, num_thresholds + 1)``
    int64 row. The expected outputs are the histograms and AUC accumulated by
    ``fluid.metrics.Auc`` on the same batch, used unchanged.
    """

    def setUp(self):
        self.op_type = "auc"
        scores = np.random.random((128, 2)).astype("float32")
        truth = np.random.randint(0, 2, (128, 1)).astype("int64")
        n_thresholds = 200
        window = 0

        stat_shape = (1, n_thresholds + 1)
        self.inputs = {
            'Predict': scores,
            'Label': truth,
            "StatPos": np.zeros(stat_shape, dtype="int64"),
            "StatNeg": np.zeros(stat_shape, dtype="int64"),
        }
        self.attrs = {
            'curve': 'ROC',
            'num_thresholds': n_thresholds,
            "slide_steps": window,
        }

        reference = metrics.Auc(
            name="auc", curve='ROC', num_thresholds=n_thresholds
        )
        reference.update(scores, truth)
        self.outputs = {
            'AUC': np.array(reference.eval()),
            'StatPosOut': np.array(reference._stat_pos),
            'StatNegOut': np.array(reference._stat_neg),
        }

    def test_check_output(self):
        # Only the static-graph path is checked (check_dygraph=False).
        self.check_output(check_dygraph=False)


class TestAucAPI(unittest.TestCase):
    """Run ``paddle.static.auc`` end to end on a small, hand-checkable batch."""

    def test_static(self):
        paddle.enable_static()

        # Graph inputs: one score per instance, its label and a tag weight.
        scores = paddle.static.data(
            name="input", shape=[-1, 1], dtype="float32"
        )
        labels = paddle.static.data(name="label", shape=[4], dtype="int64")
        weights = paddle.static.data(
            name="ins_tag_weight", shape=[4], dtype="float32"
        )
        auc_vars = paddle.static.auc(
            input=scores, label=labels, ins_tag_weight=weights
        )

        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(paddle.static.default_startup_program())

        feed_dict = {
            "input": np.array(
                [[0.0474], [0.5987], [0.7109], [0.9997]]
            ).astype("float32"),
            "label": np.array([0, 0, 1, 0]).astype('int64'),
            "ins_tag_weight": np.array([1, 1, 1, 1]).astype('float32'),
        }
        (auc_value,) = exe.run(feed=feed_dict, fetch_list=[auc_vars[0]])

        # The single positive (0.7109) outscores two of the three negatives
        # (0.0474 and 0.5987) but not 0.9997, so the expected AUC is 2/3.
        expected = np.array([0.66666667]).astype("float32")
        np.testing.assert_allclose(auc_value, expected, rtol=1e-05)


class TestAucOpError(unittest.TestCase):
    """``paddle.static.auc`` must raise TypeError for unsupported dtypes."""

    def test_errors(self):
        # TestAucAPI calls paddle.enable_static() before paddle.static.data;
        # do the same here so this case does not depend on test run order.
        paddle.enable_static()
        with fluid.program_guard(fluid.Program(), fluid.Program()):

            def test_type1():
                # Integer-typed predictions/labels/weights are expected to
                # be rejected with TypeError.
                data1 = paddle.static.data(
                    name="input1", shape=[-1, 2], dtype="int"
                )
                label1 = paddle.static.data(
                    name="label1", shape=[-1], dtype="int"
                )
                # Distinct name: this previously reused "label1", declaring
                # the same variable name twice in one program.
                ins_tag_w1 = paddle.static.data(
                    name="ins_tag_w1", shape=[-1], dtype="int"
                )
                paddle.static.auc(
                    input=data1, label=label1, ins_tag_weight=ins_tag_w1
                )

            self.assertRaises(TypeError, test_type1)

            def test_type2():
                # Float-typed labels are expected to be rejected with
                # TypeError.
                data2 = paddle.static.data(
                    name="input2", shape=[-1, 2], dtype="float32"
                )
                label2 = paddle.static.data(
                    name="label2", shape=[-1], dtype="float32"
                )
                paddle.static.auc(input=data2, label=label2)

            self.assertRaises(TypeError, test_type2)


if __name__ == '__main__':
    # Run every unittest.TestCase defined in this module.
    unittest.main()