# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle


class TestComplexCastOp(unittest.TestCase):
    """Tests for ``Tensor.cast`` to and from complex dtypes."""

    def test_complex_to_real(self):
        """Casting complex64 to int/float/bool gives that dtype and the real part."""
        r = np.random.random(size=[10, 10]) * 10
        i = np.random.random(size=[10, 10])

        c_t = paddle.to_tensor(r + i * 1j, dtype='complex64')
        # Build references from the complex64-rounded real part rather than the
        # float64 ``r``: float32 rounding can lift e.g. 2.99999999 to 3.0, which
        # would make the integer/bool comparisons below intermittently fail.
        real = (r + i * 1j).astype(np.complex64).real

        expected_dtypes = {
            'int64': paddle.int64,
            'int32': paddle.int32,
            'float32': paddle.float32,
            'float64': paddle.float64,
            'bool': paddle.bool,
        }
        for name, dtype in expected_dtypes.items():
            with self.subTest(dtype=name):
                out = c_t.cast(name)
                self.assertEqual(out.dtype, dtype)
                np.testing.assert_allclose(
                    out.numpy(), real.astype(name), rtol=1e-05
                )

    def test_real_to_complex(self):
        """Casting a real tensor to complex keeps its values as the real part."""
        r = np.random.random(size=[10, 10]) * 10
        r_t = paddle.to_tensor(r)

        for name, dtype in (
            ('complex64', paddle.complex64),
            ('complex128', paddle.complex128),
        ):
            with self.subTest(dtype=name):
                out = r_t.cast(name)
                self.assertEqual(out.dtype, dtype)
                np.testing.assert_allclose(out.real().numpy(), r, rtol=1e-05)
                # A real input carries no imaginary component.
                np.testing.assert_array_equal(
                    out.imag().numpy(), np.zeros_like(r)
                )

    def test_complex64_complex128(self):
        """Widening and narrowing casts between complex64 and complex128."""
        r = np.random.random(size=[10, 10])
        i = np.random.random(size=[10, 10])

        c = r + i * 1j
        c_64 = paddle.to_tensor(c, dtype='complex64')
        c_128 = paddle.to_tensor(c, dtype='complex128')

        widened = c_64.cast('complex128')
        narrowed = c_128.cast('complex64')

        # assertEqual, not assertTrue: assertTrue(x, y) treats y as the failure
        # message and passes for any truthy x, so it never checked the dtype.
        self.assertEqual(widened.dtype, paddle.complex128)
        self.assertEqual(narrowed.dtype, paddle.complex64)

        np.testing.assert_allclose(
            widened.numpy(), c_128.numpy(), rtol=1e-05
        )
        np.testing.assert_allclose(
            narrowed.numpy(), c_64.numpy(), rtol=1e-05
        )


if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()