#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle


class TestAllcloseOp(OpTest):
    """Checks the ``allclose`` operator against ``np.allclose``.

    Subclasses customise the case by overriding :meth:`set_args`. The
    tolerances are fed to the op as tensor inputs (``Rtol``/``Atol``),
    while ``equal_nan`` is passed as an attribute.
    """

    def set_args(self):
        """Populate operands, tolerances and ``equal_nan`` for the default case."""
        self.input = np.array([10000., 1e-07], dtype="float32")
        self.other = np.array([10000.1, 1e-08], dtype="float32")
        self.rtol = np.array([1e-05], dtype="float64")
        self.atol = np.array([1e-08], dtype="float64")
        self.equal_nan = False

    def setUp(self):
        """Translate the case from :meth:`set_args` into op inputs/attrs/outputs."""
        self.set_args()
        self.op_type = "allclose"
        self.python_api = paddle.allclose

        self.inputs = {
            'Input': self.input,
            'Other': self.other,
            "Rtol": self.rtol,
            "Atol": self.atol,
        }
        self.attrs = {'equal_nan': self.equal_nan}

        # np.allclose collapses the comparison to a single bool; the op's
        # ``Out`` is compared against it wrapped in a one-element array.
        is_close = np.allclose(self.input,
                               self.other,
                               rtol=self.rtol,
                               atol=self.atol,
                               equal_nan=self.equal_nan)
        self.outputs = {'Out': np.array([is_close])}

    def test_check_output(self):
        """Run the op and compare ``Out`` with the NumPy reference."""
        # check_eager=True additionally checks through ``python_api``
        # (presumably the eager/dygraph path -- confirm in OpTest).
        self.check_output(check_eager=True)


class TestAllcloseOpException(TestAllcloseOp):
    """Malformed ``Rtol``/``Atol`` inputs are expected to raise ValueError.

    Reuses the parent's ``setUp`` and then overwrites the tolerance tensors
    with either a two-element array or an int32 array before running the
    output check.
    """

    def test_check_output(self):
        """Each malformed (Rtol, Atol) pair must make check_output raise."""
        malformed_tolerances = (
            # Rtol has two elements.
            (np.array([1e-05, 1e-05], dtype="float64"),
             np.array([1e-08], dtype="float64")),
            # Rtol is int32.
            (np.array([5], dtype="int32"),
             np.array([1e-08], dtype="float64")),
            # Atol has two elements.
            (np.array([1e-05], dtype="float64"),
             np.array([1e-08, 1e-08], dtype="float64")),
            # Atol is int32.
            (np.array([1e-05], dtype="float64"),
             np.array([8], dtype="int32")),
        )
        for bad_rtol, bad_atol in malformed_tolerances:
            self.inputs['Rtol'] = bad_rtol
            self.inputs['Atol'] = bad_atol
            with self.assertRaises(ValueError):
                self.check_output(check_eager=True)


class TestAllcloseOpSmallNum(TestAllcloseOp):
    """Default case with the small-magnitude pair shrunk to 1e-08 vs 1e-09."""

    def set_args(self):
        """Float32 operands with one tiny pair; float64 tolerances; no NaN equality."""
        self.equal_nan = False
        self.rtol = np.array([1e-05], dtype="float64")
        self.atol = np.array([1e-08], dtype="float64")
        self.input = np.array([10000., 1e-08], dtype="float32")
        self.other = np.array([10000.1, 1e-09], dtype="float32")


class TestAllcloseOpNanFalse(TestAllcloseOp):
    """NaN in the same position of both operands, with ``equal_nan=False``."""

    def set_args(self):
        """Identical float32 operands containing a NaN; NaNs not treated as equal."""
        values = [1.0, np.nan]
        self.input = np.array(values, dtype="float32")
        self.other = np.array(values, dtype="float32")
        self.rtol = np.array([1e-05], dtype="float64")
        self.atol = np.array([1e-08], dtype="float64")
        self.equal_nan = False


class TestAllcloseOpNanTrue(TestAllcloseOp):
    """NaN in the same position of both operands, with ``equal_nan=True``."""

    def set_args(self):
        """Identical float32 operands containing a NaN; NaNs treated as equal."""
        values = [1.0, np.nan]
        self.input = np.array(values, dtype="float32")
        self.other = np.array(values, dtype="float32")
        self.rtol = np.array([1e-05], dtype="float64")
        self.atol = np.array([1e-08], dtype="float64")
        self.equal_nan = True


class TestAllcloseDygraph(unittest.TestCase):
    """Compare ``paddle.allclose`` in dynamic-graph mode with ``np.allclose``."""

    def test_api_case(self):
        """Check both a 'not close' and a 'close' operand pair.

        Dynamic-graph mode is switched on for the test and static mode is
        restored in ``finally`` so a failing assertion cannot leak dygraph
        mode into other tests.
        """
        paddle.disable_static()
        try:
            x_data = np.random.rand(10, 10)
            y_candidates = [
                # Independent random data: almost surely not close.
                np.random.rand(10, 10),
                # Offset far below atol: exercises the "close" outcome.
                x_data + 1e-9,
            ]
            x = paddle.to_tensor(x_data)
            for y_data in y_candidates:
                y = paddle.to_tensor(y_data)
                out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
                expected_out = np.allclose(x_data,
                                           y_data,
                                           rtol=1e-05,
                                           atol=1e-08)
                # The original passed ``True`` as assertTrue's *msg* argument,
                # which had no effect; assert the comparison alone.
                self.assertTrue((out.numpy() == expected_out).all())
        finally:
            paddle.enable_static()


class TestAllcloseError(unittest.TestCase):
    """Argument validation of ``paddle.allclose`` while building a static graph.

    NOTE(review): these checks rely on ``paddle.allclose`` taking its
    static-graph path; the module appears to run in static mode by default
    (``TestAllcloseDygraph`` re-enables it afterwards) -- confirm.
    """

    def test_input_dtype(self):
        """Unsupported ``x``/``y`` dtypes must raise TypeError."""

        def test_x_dtype():
            # float16 ``x`` is expected to be rejected.
            with paddle.static.program_guard(paddle.static.Program(),
                                             paddle.static.Program()):
                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
                y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
                paddle.allclose(x, y)

        self.assertRaises(TypeError, test_x_dtype)

        def test_y_dtype():
            # int32 ``y`` is expected to be rejected.
            with paddle.static.program_guard(paddle.static.Program(),
                                             paddle.static.Program()):
                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
                y = paddle.fluid.data(name='y', shape=[10, 10], dtype='int32')
                paddle.allclose(x, y)

        self.assertRaises(TypeError, test_y_dtype)

    def test_attr(self):
        """Non-float tolerances and a non-bool ``equal_nan`` must raise TypeError."""
        # Build the placeholders in a throwaway program (as test_input_dtype
        # does) instead of adding them to the global default program.
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):
            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
            y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')

            def test_rtol():
                paddle.allclose(x, y, rtol=True)

            self.assertRaises(TypeError, test_rtol)

            def test_atol():
                # Must pass ``atol`` (not ``rtol``) so the atol type check
                # is what is actually exercised here.
                paddle.allclose(x, y, atol=True)

            self.assertRaises(TypeError, test_atol)

            def test_equal_nan():
                paddle.allclose(x, y, equal_nan=1)

            self.assertRaises(TypeError, test_equal_nan)


class TestAllcloseOpFloat32(TestAllcloseOp):
    """Single float32 pair compared with a purely relative tolerance (atol = 0)."""

    def set_args(self):
        """10.1 vs 10 in float32, rtol 0.01, atol 0, no NaN equality."""
        operand_dtype = "float32"
        self.input = np.array([10.1], dtype=operand_dtype)
        self.other = np.array([10.0], dtype=operand_dtype)
        self.rtol = np.array([0.01], dtype="float64")
        self.atol = np.array([0.0], dtype="float64")
        self.equal_nan = False


class TestAllcloseOpFloat64(TestAllcloseOp):
    """Single float64 pair compared with a purely relative tolerance (atol = 0)."""

    def set_args(self):
        """10.1 vs 10 in float64, rtol 0.01, atol 0, no NaN equality."""
        operand_dtype = "float64"
        self.input = np.array([10.1], dtype=operand_dtype)
        self.other = np.array([10.0], dtype=operand_dtype)
        self.rtol = np.array([0.01], dtype="float64")
        self.atol = np.array([0.0], dtype="float64")
        self.equal_nan = False


class TestAllcloseOpLargeDimInput(TestAllcloseOp):
    """Large 2048x1024 float64 operands differing in a single element."""

    def set_args(self):
        """All-zero operands, except the last element of ``input`` set to 100."""
        shape = (2048, 1024)
        self.input = np.zeros(shape, dtype="float64")
        self.other = np.zeros(shape, dtype="float64")
        # The only mismatch sits in the very last element.
        self.input[-1, -1] = 100
        self.rtol = np.array([1e-05], dtype="float64")
        self.atol = np.array([1e-08], dtype="float64")
        self.equal_nan = False


# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()