test_atan2_op.py 5.3 KB
Newer Older
R
ronnywang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

17
import numpy as np
18
from eager_op_test import OpTest, convert_float_to_uint16
19

R
ronnywang 已提交
20
import paddle
21
from paddle.fluid import core
R
ronnywang 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35

paddle.enable_static()
np.random.seed(0)


def atan2_grad(x1, x2, dout):
    dx1 = dout * x2 / (x1 * x1 + x2 * x2)
    dx2 = -dout * x1 / (x1 * x1 + x2 * x2)
    return dx1, dx2


class TestAtan2(OpTest):
    def setUp(self):
        self.op_type = "atan2"
H
hong 已提交
36
        self.python_api = paddle.atan2
37
        self.check_cinn = True
R
ronnywang 已提交
38 39 40 41 42 43 44 45 46 47
        self.init_dtype()

        x1 = np.random.uniform(-1, -0.1, [15, 17]).astype(self.dtype)
        x2 = np.random.uniform(0.1, 1, [15, 17]).astype(self.dtype)
        out = np.arctan2(x1, x2)

        self.inputs = {'X1': x1, 'X2': x2}
        self.outputs = {'Out': out}

    def test_check_grad(self):
48
        self.check_grad(['X1', 'X2'], 'Out', check_cinn=self.check_cinn)
R
ronnywang 已提交
49 50

    def test_check_output(self):
51
        self.check_output(check_cinn=self.check_cinn)
R
ronnywang 已提交
52 53 54 55 56 57 58 59 60 61 62

    def init_dtype(self):
        self.dtype = np.float64


class TestAtan2_float(TestAtan2):
    def init_dtype(self):
        self.dtype = np.float32

    def test_check_grad(self):
        if self.dtype not in [np.int32, np.int64]:
63 64 65 66 67 68 69 70
            self.check_grad(
                ['X1', 'X2'],
                'Out',
                user_defined_grads=atan2_grad(
                    self.inputs['X1'],
                    self.inputs['X2'],
                    1 / self.inputs['X1'].size,
                ),
71
                check_cinn=self.check_cinn,
72
            )
R
ronnywang 已提交
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107


class TestAtan2_float16(TestAtan2_float):
    def init_dtype(self):
        self.dtype = np.float16


class TestAtan2_int32(TestAtan2_float):
    def init_dtype(self):
        self.dtype = np.int32


class TestAtan2_int64(TestAtan2_float):
    def init_dtype(self):
        self.dtype = np.int64


class TestAtan2API(unittest.TestCase):
    def init_dtype(self):
        self.dtype = 'float64'
        self.shape = [11, 17]

    def setUp(self):
        self.init_dtype()
        self.x1 = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
        self.x2 = np.random.uniform(-1, -0.1, self.shape).astype(self.dtype)
        self.place = [paddle.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def test_static_api(self):
        paddle.enable_static()

        def run(place):
            with paddle.static.program_guard(paddle.static.Program()):
108 109
                X1 = paddle.static.data('X1', self.shape, dtype=self.dtype)
                X2 = paddle.static.data('X2', self.shape, dtype=self.dtype)
R
ronnywang 已提交
110 111 112 113 114
                out = paddle.atan2(X1, X2)
                exe = paddle.static.Executor(place)
                res = exe.run(feed={'X1': self.x1, 'X2': self.x2})
            out_ref = np.arctan2(self.x1, self.x2)
            for r in res:
115
                np.testing.assert_allclose(out_ref, r, rtol=1e-05)
R
ronnywang 已提交
116 117 118 119 120 121 122 123 124 125 126

        for place in self.place:
            run(place)

    def test_dygraph_api(self):
        def run(place):
            paddle.disable_static(place)
            X1 = paddle.to_tensor(self.x1)
            X2 = paddle.to_tensor(self.x2)
            out = paddle.atan2(X1, X2)
            out_ref = np.arctan2(self.x1, self.x2)
127
            np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
R
ronnywang 已提交
128 129 130 131 132 133
            paddle.enable_static()

        for place in self.place:
            run(place)


134 135 136 137 138 139 140 141 142 143
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and not support the bfloat16",
)
class TestAtan2BF16OP(OpTest):
    def setUp(self):
        self.op_type = 'atan2'
        self.python_api = paddle.atan2
        self.dtype = np.uint16
144
        self.check_cinn = True
145 146 147 148 149 150 151 152 153 154 155 156
        x1 = np.random.uniform(-1, -0.1, [15, 17]).astype('float32')
        x2 = np.random.uniform(0.1, 1, [15, 17]).astype('float32')
        out = np.arctan2(x1, x2)

        self.inputs = {
            'X1': convert_float_to_uint16(x1),
            'X2': convert_float_to_uint16(x2),
        }
        self.outputs = {'Out': convert_float_to_uint16(out)}

    def test_check_output(self):
        place = core.CUDAPlace(0)
157
        self.check_output_with_place(place, check_cinn=self.check_cinn)
158 159 160

    def test_check_grad(self):
        place = core.CUDAPlace(0)
161 162 163
        self.check_grad_with_place(
            place, ['X1', 'X2'], 'Out', check_cinn=self.check_cinn
        )
164 165


166 167 168 169 170
class TestAtan2Error(unittest.TestCase):
    def test_mismatch(self):
        paddle.enable_static()

        def test_mismatch_numel():
171 172
            X = paddle.static.data('X', (1,), dtype=np.float64)
            Y = paddle.static.data('Y', (0,), dtype=np.float64)
173 174 175 176 177
            out = paddle.atan2(X, Y)

        self.assertRaises(ValueError, test_mismatch_numel)


R
ronnywang 已提交
178
if __name__ == '__main__':
H
hong 已提交
179
    paddle.enable_static()
R
ronnywang 已提交
180
    unittest.main()