#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from eager_op_test import OpTest

import paddle
from paddle import fluid


class TestAucOp(OpTest):
    """Check the ``auc`` operator with a sliding window of one step.

    The expected AUC and stat buffers are derived from ``paddle.metric.Auc``
    updated with the very same random predictions and labels.
    """

    def setUp(self):
        self.op_type = "auc"
        # Predictions are drawn before labels; keep this order so the random
        # stream is consumed exactly as before.
        predictions = np.random.random((128, 2)).astype("float32")
        truth = np.random.randint(0, 2, (128, 1)).astype("int64")
        thresholds = 200
        window = 1

        # Each stat buffer holds (1 + slide_steps) blocks of
        # (num_thresholds + 1) bins, plus one extra trailing slot.
        buffer_len = (1 + window) * (thresholds + 1) + 1
        self.inputs = {
            'Predict': predictions,
            'Label': truth,
            "StatPos": np.zeros(buffer_len).astype("int64"),
            "StatNeg": np.zeros(buffer_len).astype("int64"),
        }
        self.attrs = {
            'curve': 'ROC',
            'num_thresholds': thresholds,
            "slide_steps": window,
        }

        reference = paddle.metric.Auc(
            name="auc", curve='ROC', num_thresholds=thresholds
        )
        reference.update(predictions, truth)

        # Expected buffers: the reference histogram repeated twice, followed
        # by a single 1 in the trailing slot.
        expected_pos = reference._stat_pos.tolist() * 2 + [1]
        expected_neg = reference._stat_neg.tolist() * 2 + [1]
        self.outputs = {
            'AUC': np.array(reference.accumulate()),
            'StatPosOut': np.array(expected_pos),
            'StatNegOut': np.array(expected_neg),
        }

    def test_check_output(self):
        # Only the static-graph path is checked; dygraph checking is off.
        self.check_output(check_dygraph=False)


class TestGlobalAucOp(OpTest):
    """Check the ``auc`` operator in global mode (``slide_steps`` = 0).

    The stat buffers are 2-D with shape ``(1, num_thresholds + 1)``. The
    expected values come from ``paddle.metric.Auc`` run on the same batch.
    """

    def setUp(self):
        self.op_type = "auc"
        # Predictions are drawn before labels; keep this order so the random
        # stream is consumed exactly as before.
        predictions = np.random.random((128, 2)).astype("float32")
        truth = np.random.randint(0, 2, (128, 1)).astype("int64")
        thresholds = 200
        window = 0

        stat_shape = (1, thresholds + 1)
        self.inputs = {
            'Predict': predictions,
            'Label': truth,
            "StatPos": np.zeros(stat_shape).astype("int64"),
            "StatNeg": np.zeros(stat_shape).astype("int64"),
        }
        self.attrs = {
            'curve': 'ROC',
            'num_thresholds': thresholds,
            "slide_steps": window,
        }

        reference = paddle.metric.Auc(
            name="auc", curve='ROC', num_thresholds=thresholds
        )
        reference.update(predictions, truth)

        pos_hist = reference._stat_pos
        neg_hist = reference._stat_neg
        self.outputs = {
            'AUC': np.array(reference.accumulate()),
            # Wrapping in a list adds a leading axis of length 1, matching
            # the 2-D stat buffers fed as inputs.
            'StatPosOut': np.array([pos_hist]),
            'StatNegOut': np.array([neg_hist]),
        }

    def test_check_output(self):
        # Only the static-graph path is checked; dygraph checking is off.
        self.check_output(check_dygraph=False)


class TestAucAPI(unittest.TestCase):
    """End-to-end check of ``paddle.static.auc`` on a small, hand-checked batch."""

    def test_static(self):
        # Switches Paddle into static-graph mode globally; this test does not
        # switch it back.
        paddle.enable_static()
        data = paddle.static.data(name="input", shape=[-1, 1], dtype="float32")
        label = paddle.static.data(name="label", shape=[4], dtype="int64")
        ins_tag_weight = paddle.static.data(
            name="ins_tag_weight", shape=[4], dtype="float32"
        )
        result = paddle.static.auc(
            input=data, label=label, ins_tag_weight=ins_tag_weight
        )

        place = paddle.CPUPlace()
        exe = paddle.static.Executor(place)

        exe.run(paddle.static.default_startup_program())

        x = np.array([[0.0474], [0.5987], [0.7109], [0.9997]]).astype("float32")

        y = np.array([0, 0, 1, 0]).astype('int64')
        z = np.array([1, 1, 1, 1]).astype('float32')
        (output,) = exe.run(
            feed={"input": x, "label": y, "ins_tag_weight": z},
            fetch_list=[result[0]],
        )
        # The single positive (score 0.7109) outranks two of the three
        # negatives (0.0474, 0.5987) but not 0.9997, so the AUC is 2/3.
        auc_np = np.array(0.66666667).astype("float32")
        np.testing.assert_allclose(output, auc_np, rtol=1e-05)
        # The fetched AUC must have the same (0-D) shape as the expected value.
        self.assertEqual(output.shape, auc_np.shape)


class TestAucOpError(unittest.TestCase):
    """Check that ``paddle.static.auc`` raises TypeError for unsupported dtypes."""

    def test_errors(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):

            def test_type1():
                # Integer-typed prediction input is expected to be rejected.
                data1 = paddle.static.data(
                    name="input1", shape=[-1, 2], dtype="int"
                )
                label1 = paddle.static.data(
                    name="label1", shape=[-1], dtype="int"
                )
                # Give the weight its own variable name instead of
                # duplicating "label1" (a copy-paste slip).
                ins_tag_w1 = paddle.static.data(
                    name="ins_tag_w1", shape=[-1], dtype="int"
                )
                paddle.static.auc(
                    input=data1, label=label1, ins_tag_weight=ins_tag_w1
                )

            self.assertRaises(TypeError, test_type1)

            def test_type2():
                # Floating-point labels are expected to be rejected.
                data2 = paddle.static.data(
                    name="input2", shape=[-1, 2], dtype="float32"
                )
                label2 = paddle.static.data(
                    name="label2", shape=[-1], dtype="float32"
                )
                paddle.static.auc(input=data2, label=label2)

            self.assertRaises(TypeError, test_type2)


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()