test_fft.py 44.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import re
import sys
import unittest

import numpy as np
import paddle
import scipy.fft

# Devices every test is run on: always the CPU, plus the first GPU when
# Paddle was built with CUDA support.
DEVICES = [paddle.CPUPlace()] + (
    [paddle.CUDAPlace(0)] if paddle.is_compiled_with_cuda() else [])

# Field name that the parameterized cases below use for their name suffix
# (``parameterize`` reads the 'suffix' key when naming generated classes).
TEST_CASE_NAME = 'suffix'

# Comparison tolerances keyed by the dtype name of the input array.
# All test cases compare in float64 precision where possible, see:
# https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64
RTOL = dict(float32=1e-03, complex64=1e-3, float64=1e-7, complex128=1e-7)
ATOL = dict(float32=0.0, complex64=0, float64=0.0, complex128=0)


def rand_x(dims=1,
           dtype='float64',
           min_dim_len=1,
           max_dim_len=10,
           complex=False):
    """Return a random array with ``dims`` dimensions for the FFT tests.

    Every dimension length is drawn independently with
    ``np.random.randint(min_dim_len, max_dim_len)``, so the upper bound is
    exclusive. Values are drawn from ``np.random.randn``.

    Args:
        dims: number of dimensions of the returned array.
        dtype: dtype the real (and imaginary) draws are cast to.
        min_dim_len: inclusive lower bound of every dimension length.
        max_dim_len: exclusive upper bound of every dimension length.
        complex: when True, return ``real + 1j * imag`` built from two
            independent draws of the same shape. (The name shadows the
            builtin but is kept because callers pass it by keyword.)

    Returns:
        A numpy array of the drawn shape.
    """
    shape = [np.random.randint(min_dim_len, max_dim_len) for _ in range(dims)]
    if complex:
        return np.random.randn(
            *shape).astype(dtype) + 1.j * np.random.randn(*shape).astype(dtype)
    return np.random.randn(*shape).astype(dtype)


def place(devices, key='place'):
    """Class decorator that clones parameterized test classes per device.

    Every module-level class derived directly from ``cls`` whose name starts
    with ``cls.__name__`` (the classes ``parameterize`` generates) is
    replaced by one subclass per entry of ``devices``. Each clone is named
    ``<raw name>.<device class name>`` and gets the device as the class
    attribute ``key``.

    Args:
        devices: sequence of device objects (it is iterated once per
            generated class), e.g. ``paddle.CPUPlace()``.
        key: name of the class attribute that receives the device.

    Returns:
        A decorator that rewrites the module namespace and returns ``cls``
        unchanged.
    """

    def decorate(cls):
        module = sys.modules[cls.__module__].__dict__
        # A name-prefix match alone is not enough: the case "TestFft2.<sfx>"
        # generated for ``TestFft`` also starts with "TestFft2", so
        # decorating ``TestFft2`` would clone it a second time. Only take
        # classes that were derived directly from ``cls``.
        raw_classes = {
            k: v
            for k, v in module.items()
            if k.startswith(cls.__name__) and isinstance(v, type) and
            cls in v.__bases__
        }

        for raw_name, raw_cls in raw_classes.items():
            for d in devices:
                test_cls = dict(raw_cls.__dict__)
                test_cls.update({key: d})
                new_name = raw_name + '.' + d.__class__.__name__
                module[new_name] = type(new_name, (raw_cls, ), test_cls)
            # The un-placed class normally lacks the ``key`` attribute, so
            # remove it to keep unittest from collecting it.
            del module[raw_name]
        return cls

    return decorate


def parameterize(fields, values=None):

    fields = [fields] if isinstance(fields, str) else fields
    params = [dict(zip(fields, vals)) for vals in values]

    def decorate(cls):
        test_cls_module = sys.modules[cls.__module__].__dict__
        for k, v in enumerate(params):
            test_cls = dict(cls.__dict__)
            test_cls.update(v)
            name = cls.__name__ + str(k)
            name = name + '.' + v.get('suffix') if v.get('suffix') else name

            test_cls_module[name] = type(name, (cls, ), test_cls)

        for m in list(cls.__dict__):
            if m.startswith("test"):
                delattr(cls, m)
        return cls

    return decorate


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
    [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
     ('test_x_complex', rand_x(5, complex=True), None, -1, 'backward'),
     ('test_n_grater_input_length', rand_x(5, max_dim_len=5), 11, -1,
      'backward'),
     ('test_n_smaller_than_input_length',
      rand_x(5, min_dim_len=5, complex=True), 3, -1, 'backward'),
     ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
     ('test_norm_forward', rand_x(5), None, 3, 'forward'),
     ('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestFft(unittest.TestCase):
    """Compare ``paddle.fft.fft`` with ``scipy.fft.fft`` per case and device.

    ``place``, ``x``, ``n``, ``axis`` and ``norm`` are set on the generated
    subclasses by the ``place`` and ``parameterize`` decorators.
    """

    def test_fft(self):
        """Test fft with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            # assert_allclose reports the mismatching elements and checks
            # shapes strictly, unlike assertTrue(np.allclose(...)).
            np.testing.assert_allclose(
                scipy.fft.fft(self.x, self.n, self.axis, self.norm),
                paddle.fft.fft(paddle.to_tensor(self.x), self.n, self.axis,
                               self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
    [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
     ('test_x_complex', rand_x(5, complex=True), None, -1, 'backward'),
     ('test_n_grater_input_length', rand_x(5, max_dim_len=5), 11, -1,
      'backward'),
     ('test_n_smaller_than_input_length',
      rand_x(5, min_dim_len=5, complex=True), 3, -1, 'backward'),
     ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
     ('test_norm_forward', rand_x(5), None, 3, 'forward'),
     ('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestIfft(unittest.TestCase):
    """Compare ``paddle.fft.ifft`` with ``scipy.fft.ifft`` per case and device.

    ``place``, ``x``, ``n``, ``axis`` and ``norm`` are set on the generated
    subclasses by the ``place`` and ``parameterize`` decorators.
    """

    def test_ifft(self):
        """Test ifft with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ifft(self.x, self.n, self.axis, self.norm),
                paddle.fft.ifft(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
     ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
     ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
     ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
     ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)]
)
class TestFftException(unittest.TestCase):
    """Check that ``paddle.fft.fft`` raises ``expect_exception`` for bad args."""

    def test_fft(self):
        """Test fft with boundary condition
        Test case include:
        - n out of range
        - axis out of range
        - axis type error
        - norm out of range
        """
        # Run under the injected device; without the guard every device
        # variant used Paddle's default place.
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.fft(paddle.to_tensor(self.x), self.n, self.axis,
                               self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
    ('test_x_complex128', rand_x(5, complex=True), None, (0, 1), 'backward'),
    ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (6, 6),
     (0, 1), 'backward'),
    ('test_n_smaller_than_input_length', rand_x(5, min_dim_len=5, complex=True),
     (4, 4), (0, 1), 'backward'),
    ('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
    ('test_axis_none', rand_x(5), None, None, 'backward'),
    ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
    ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestFft2(unittest.TestCase):
    """Compare ``paddle.fft.fft2`` with ``scipy.fft.fft2`` per case and device."""

    def test_fft2(self):
        """Test fft2 with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fft2(self.x, self.n, self.axis, self.norm),
                paddle.fft.fft2(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_x_complex_input', rand_x(2, complex=True), None,
      (0, 1), None, ValueError),
     ('test_x_1dim_tensor', rand_x(1), None, (0, 1), None, ValueError),
     ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
     ('test_n_len_not_equal_axis', rand_x(5, max_dim_len=5), 11,
      (0, 1), 'backward', ValueError),
     ('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', ValueError),
     ('test_axis_out_of_range', rand_x(2), None,
      (0, 1, 2), 'backward', ValueError),
     ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
     ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', ValueError),
     # NOTE(review): axis=-1 is not a sequence, so this case may raise on
     # the axis check before norm is validated -- confirm it really
     # exercises the norm check.
     ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)])
class TestFft2Exception(unittest.TestCase):
    """Check that ``paddle.fft.fft2`` raises ``expect_exception`` for bad args."""

    def test_fft2(self):
        """Test fft2 with boundary condition
        Test case include:
        - input type error
        - input dim error
        - n out of range
        - axis out of range
        - axis type error
        - norm out of range
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.fft2(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm)


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
    [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
     ('test_x_complex128', rand_x(5, complex=True), None, None, 'backward'),
     ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (6, 6),
      (1, 2), 'backward'),
     ('test_n_smaller_input_length', rand_x(5, min_dim_len=5, complex=True),
      (3, 3), (1, 2), 'backward'),
     ('test_axis_not_default', rand_x(5), None, (1, 2), 'backward'),
     ('test_norm_forward', rand_x(5), None, None, 'forward'),
     ('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestFftn(unittest.TestCase):
    """Compare ``paddle.fft.fftn`` with ``scipy.fft.fftn`` per case and device."""

    def test_fftn(self):
        """Test fftn with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fftn(self.x, self.n, self.axis, self.norm),
                paddle.fft.fftn(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
    [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
     ('test_x_complex128', rand_x(5, complex=True), None, None, 'backward'),
     ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (6, 6),
      (1, 2), 'backward'),
     ('test_n_smaller_input_length', rand_x(5, min_dim_len=5, complex=True),
      (3, 3), (1, 2), 'backward'),
     ('test_axis_not_default', rand_x(5), None, (1, 2), 'backward'),
     ('test_norm_forward', rand_x(5), None, None, 'forward'),
     ('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestIFftn(unittest.TestCase):
    """Compare ``paddle.fft.ifftn`` with ``scipy.fft.ifftn`` per case and device."""

    def test_ifftn(self):
        """Test ifftn with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ifftn(self.x, self.n, self.axis, self.norm),
                paddle.fft.ifftn(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_complex128',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.complex128), None, -1, "backward"),
    ('test_n_grater_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), 4, -1, "backward"),
    ('test_n_smaller_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), 2, -1, "backward"),
    ('test_axis_not_last', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, 1, "backward"),
    ('test_norm_forward', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, 1, "forward"),
    ('test_norm_ortho', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, -1, "ortho"),
])
class TestHfft(unittest.TestCase):
    """Compare ``paddle.fft.hfft`` with ``scipy.fft.hfft`` per case and device."""

    def test_hfft(self):
        """Test hfft with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.hfft(self.x, self.n, self.axis, self.norm),
                paddle.fft.hfft(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm),
                rtol=1e-5,
                atol=0)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_complex128',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.complex128), None, -1, "backward"),
    ('test_n_grater_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), 4, -1, "backward"),
    ('test_n_smaller_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), 2, -1, "backward"),
    # Transform along a non-last axis, as the case name says (it used -1).
    ('test_axis_not_last', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, 1, "backward"),
    ('test_norm_forward', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, -1, "forward"),
    ('test_norm_ortho', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, -1, "ortho"),
])
class TestIrfft(unittest.TestCase):
    """Compare ``paddle.fft.irfft`` with ``scipy.fft.irfft`` per case and device."""

    def test_irfft(self):
        """Test irfft with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.irfft(self.x, self.n, self.axis, self.norm),
                paddle.fft.irfft(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm),
                rtol=1e-5,
                atol=0)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_complex128',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.complex128), None, None, "backward"),
    ('test_n_grater_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), [4], None, "backward"),
    ('test_n_smaller_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), [2], None, "backward"),
    ('test_axis_not_last', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, None, "backward"),
    ('test_norm_forward', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, None, "forward"),
    ('test_norm_ortho', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, None, "ortho"),
])
class TestIrfftn(unittest.TestCase):
    """Compare ``paddle.fft.irfftn`` with ``scipy.fft.irfftn`` per case and device."""

    def test_irfftn(self):
        """Test irfftn with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.irfftn(self.x, self.n, self.axis, self.norm),
                paddle.fft.irfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm),
                rtol=1e-5,
                atol=0)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_complex128',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.complex128), None, None, "backward"),
    ('test_n_grater_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), [4], None, "backward"),
    ('test_n_smaller_than_input_length', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), [2], None, "backward"),
    ('test_axis_not_last', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, None, "backward"),
    ('test_norm_forward', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, None, "forward"),
    ('test_norm_ortho', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), None, None, "ortho"),
])
class TestHfftn(unittest.TestCase):
    """Compare ``paddle.fft.hfftn`` with ``scipy.fft.hfftn`` per case and device."""

    def test_hfftn(self):
        """Test hfftn with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.hfftn(self.x, self.n, self.axis, self.norm),
                paddle.fft.hfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm),
                rtol=1e-5,
                atol=0)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [
    ('test_x_complex128',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.complex128), None, (-2, -1), "backward"),
    # The former sixth item (ValueError) had no matching field and was
    # silently dropped by zip(); this case is a plain comparison.
    ('test_with_s', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
     [2, 2], (-2, -1), "backward"),
    ('test_axis_not_last',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None,
     (-2, -1), "backward"),
    ('test_norm_forward',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None,
     (-2, -1), "forward"),
    ('test_norm_ortho',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None,
     (-2, -1), "ortho"),
])
class TestHfft2(unittest.TestCase):
    """Compare ``paddle.fft.hfft2`` with ``scipy.fft.hfft2`` per case and device."""

    def test_hfft2(self):
        """Test hfft2 with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.hfft2(self.x, self.s, self.axis, self.norm),
                paddle.fft.hfft2(paddle.to_tensor(self.x), self.s, self.axis,
                                 self.norm),
                rtol=1e-5,
                atol=0)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 's', 'axis', 'norm'), [
    ('test_x_complex128',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.complex128), None, (-2, -1), "backward"),
    ('test_n_equal_input_length',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (4, 6),
     (-2, -1), "backward"),
    ('test_axis_not_last',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None,
     (-2, -1), "backward"),
    ('test_norm_forward',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None,
     (-2, -1), "forward"),
    ('test_norm_ortho',
     np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), None,
     (-2, -1), "ortho"),
])
class TestIrfft2(unittest.TestCase):
    """Compare ``paddle.fft.irfft2`` with ``scipy.fft.irfft2`` per case and device."""

    def test_irfft2(self):
        """Test irfft2 with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.irfft2(self.x, self.s, self.axis, self.norm),
                paddle.fft.irfft2(paddle.to_tensor(self.x), self.s, self.axis,
                                  self.norm),
                rtol=1e-5,
                atol=0)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
    ('test_bool_input',
     (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
         np.bool_), None, -1, 'backward', NotImplementedError),
    ('test_n_nagative', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), -1, -1, 'backward', ValueError),
    ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1,
     'backward', ValueError),
    ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
     (1, 2, 3), -1, 'backward', ValueError),
    ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
     None, 10, 'backward', ValueError),
    ('test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), None,
     (0, 1), 'backward', ValueError),
    ('test_norm_not_in_enum_value', np.random.randn(4, 4) +
     1j * np.random.randn(4, 4), None, -1, 'random', ValueError)
])
class TestHfftException(unittest.TestCase):
    """Check that ``paddle.fft.hfft`` raises ``expect_exception`` for bad args."""

    def test_hfft(self):
        """Test hfft with boundary condition
        Test case include:
        - input type error
        - n out of range
        - n type error
        - axis out of range
        - axis type error
        - norm out of range
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.hfft(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
    ('test_n_nagative', np.random.randn(4, 4, 4) +
     1j * np.random.randn(4, 4, 4), -1, -1, 'backward', ValueError),
    ('test_n_zero', np.random.randn(4, 4) + 1j * np.random.randn(4, 4), 0, -1,
     'backward', ValueError),
    ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
     (1, 2), -1, 'backward', ValueError),
    ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
     None, 10, 'backward', ValueError),
    ('test_axis_with_array', np.random.randn(4) + 1j * np.random.randn(4), None,
     (0, 1), 'backward', ValueError),
    ('test_norm_not_in_enum_value', np.random.randn(4, 4) +
     1j * np.random.randn(4, 4), None, None, 'random', ValueError)
])
class TestIrfftException(unittest.TestCase):
    """Check that ``paddle.fft.irfft`` raises ``expect_exception`` for bad args."""

    def test_irfft(self):
        """
        Test irfft with boundary condition
        Test case include:
        - n out of range
        - n type error
        - axis type error
        - axis out of range
        - norm out of range
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.irfft(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm)


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_bool_input',
      (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
          np.bool_), None, (-2, -1), 'backward', NotImplementedError),
     ('test_n_nagative',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
      (-2, -1), 'backward', ValueError),
     ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      (0, 0), (-2, -1), 'backward', ValueError),
     ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      3, None, 'backward', ValueError),
     # axis is the plain int -1 here (the original ``(-1)`` is not a tuple).
     ('test_n_axis_dim',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2),
      -1, 'backward', ValueError),
     ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
      None, (1, 2), 'backward', ValueError),
     ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, -1,
      'backward', ValueError),
     ('test_norm_not_in_enum_value', np.random.randn(4, 4) +
      1j * np.random.randn(4, 4), None, None, 'random', ValueError)])
class TestHfft2Exception(unittest.TestCase):
    """Check that ``paddle.fft.hfft2`` raises ``expect_exception`` for bad args."""

    def test_hfft2(self):
        """
        Test hfft2 with boundary condition
        Test case include:
        - input type error
        - n type error
        - n out of range
        - axis out of range
        - the dimensions of n and axis are different
        - norm out of range
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.hfft2(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm)


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_n_nagative',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
      (-2, -1), 'backward', ValueError),
     ('test_zero_point',
      np.random.randn(4, 4, 1) + 1j * np.random.randn(4, 4, 1), None,
      (-2, -1), "backward", ValueError),
     ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      (0, 0), (-2, -1), 'backward', ValueError),
     ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      3, -1, 'backward', ValueError),
     ('test_n_axis_dim',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2),
      (-3, -2, -1), 'backward', ValueError),
     ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
      None, (1, 2), 'backward', ValueError),
     ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1,
      'backward', ValueError),
     ('test_norm_not_in_enum_value', np.random.randn(4, 4) +
      1j * np.random.randn(4, 4), None, None, 'random', ValueError)])
class TestIrfft2Exception(unittest.TestCase):
    """Check that ``paddle.fft.irfft2`` raises ``expect_exception`` for bad args."""

    def test_irfft2(self):
        """
        Test irfft2 with boundary condition
        Test case include:
        - input whose last transformed axis has length 1
        - n type error
        - n out of range
        - axis out of range
        - the dimensions of n and axis are different
        - norm out of range
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.irfft2(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm)


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_bool_input',
      (np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4)).astype(
          np.bool_), None, (-2, -1), 'backward', NotImplementedError),
     ('test_n_nagative',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
      (-2, -1), 'backward', ValueError),
     ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      (0, 0), (-2, -1), 'backward', ValueError),
     ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      3, -1, 'backward', ValueError),
     ('test_n_axis_dim',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2),
      (-3, -2, -1), 'backward', ValueError),
     ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
      None, (10, 20), 'backward', ValueError),
     ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1,
      'backward', ValueError),
     ('test_norm_not_in_enum_value', np.random.randn(4, 4) +
      1j * np.random.randn(4, 4), None, None, 'random', ValueError)])
class TestHfftnException(unittest.TestCase):
    """Check that ``paddle.fft.hfftn`` raises ``expect_exception`` for bad args."""

    def test_hfftn(self):
        """Test hfftn with boundary condition
        Test case include:
        - input type error
        - n type error
        - n out of range
        - axis out of range
        - axis type error
        - the dimensions of n and axis are different
        - norm out of range
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.hfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm)


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_n_nagative',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (-1, -2),
      (-2, -1), 'backward', ValueError),
     ('test_n_zero', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      (0, 0), (-2, -1), 'backward', ValueError),
     ('test_n_type', np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4),
      3, -1, 'backward', ValueError),
     ('test_n_axis_dim',
      np.random.randn(4, 4, 4) + 1j * np.random.randn(4, 4, 4), (1, 2),
      (-3, -2, -1), 'backward', ValueError),
     ('test_axis_out_of_range', np.random.randn(4) + 1j * np.random.randn(4),
      None, (10, 20), 'backward', ValueError),
     ('test_axis_type', np.random.randn(4) + 1j * np.random.randn(4), None, 1,
      'backward', ValueError),
     ('test_norm_not_in_enum_value', np.random.randn(4, 4) +
      1j * np.random.randn(4, 4), None, None, 'random', ValueError)])
class TestIrfftnException(unittest.TestCase):
    """Check that ``paddle.fft.irfftn`` raises ``expect_exception`` for bad args."""

    def test_irfftn(self):
        """Test irfftn with boundary condition
        Test case include:
        - n out of range
        - n type error
        - axis out of range
        - axis type error
        - norm out of range
        - the dimensions of n and axis are different
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.irfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
              [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
               ('test_n_grater_than_input_length', rand_x(
                   5, max_dim_len=5), 11, -1, 'backward'),
               ('test_n_smaller_than_input_length', rand_x(
                   5, min_dim_len=5), 3, -1, 'backward'),
               ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
               ('test_norm_forward', rand_x(5), None, 3, 'forward'),
               ('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestRfft(unittest.TestCase):
    """Compare ``paddle.fft.rfft`` with ``scipy.fft.rfft`` per case and device."""

    def test_rfft(self):
        """Test rfft with norm condition
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.rfft(self.x, self.n, self.axis, self.norm),
                paddle.fft.rfft(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
     ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
     ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
     ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
     ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)]
)
class TestRfftException(unittest.TestCase):

    def test_rfft(self):
        """Test that paddle.fft.rfft rejects invalid arguments.

        Test cases include:
        - n out of range (negative or zero)
        - axis out of range
        - axis type error (a sequence instead of an int)
        - norm not in the enum values
        """
        # Run on the place injected by @place(DEVICES); without this guard
        # every device variant would silently run on the default device.
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.rfft(paddle.to_tensor(self.x), self.n, self.axis,
                                self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
    ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (6, 6),
     (0, 1), 'backward'),
    ('test_n_smaller_than_input_length', rand_x(5, min_dim_len=5), (4, 4),
     (0, 1), 'backward'),
    ('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
    ('test_axis_none', rand_x(5), None, None, 'backward'),
    ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
    ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestRfft2(unittest.TestCase):

    def test_rfft2(self):
        """Compare paddle.fft.rfft2 against scipy.fft.rfft2.

        Covers padding / truncation through ``n``, non-default and ``None``
        axes, and the ``forward`` / ``ortho`` norm modes.
        """
        with paddle.fluid.dygraph.guard(self.place):
            # assert_allclose rejects mismatched shapes (np.allclose would
            # broadcast them) and reports the offending elements on failure.
            np.testing.assert_allclose(
                scipy.fft.rfft2(self.x, self.n, self.axis, self.norm),
                paddle.fft.rfft2(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
    ('test_x_complex_input', rand_x(2, complex=True), None,
     (0, 1), 'backward', RuntimeError),
    ('test_x_1dim_tensor', rand_x(1), None, (0, 1), 'backward', ValueError),
    ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
    ('test_n_zero', rand_x(2), 0, (0, 1), 'backward', ValueError),
    ('test_axis_out_of_range', rand_x(2), None,
     (0, 1, 2), 'backward', ValueError),
    ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
    ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', ValueError),
    ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError),
])
class TestRfft2Exception(unittest.TestCase):

    def test_rfft2(self):
        """Test that paddle.fft.rfft2 rejects invalid arguments.

        Test cases include:
        - input type error (complex input)
        - input dim error (1-D input)
        - n out of range (negative or zero)
        - axis out of range
        - axis not a sequence
        - norm not in the enum values
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.rfft2(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
    ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (6, 6),
     (1, 2), 'backward'),
    ('test_n_smaller_input_length', rand_x(5, min_dim_len=5), (3, 3),
     (1, 2), 'backward'),
    ('test_axis_not_default', rand_x(5), None, (1, 2), 'backward'),
    ('test_norm_forward', rand_x(5), None, None, 'forward'),
    ('test_norm_ortho', rand_x(5), None, None, 'ortho'),
])
class TestRfftn(unittest.TestCase):

    def test_rfftn(self):
        """Compare paddle.fft.rfftn against scipy.fft.rfftn.

        Covers padding / truncation through ``n``, non-default axes and the
        ``forward`` / ``ortho`` norm modes.
        """
        with paddle.fluid.dygraph.guard(self.place):
            # assert_allclose rejects mismatched shapes (np.allclose would
            # broadcast them) and reports the offending elements on failure.
            np.testing.assert_allclose(
                scipy.fft.rfftn(self.x, self.n, self.axis, self.norm),
                paddle.fft.rfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
    ('test_x_complex', rand_x(
        4, complex=True), None, None, 'backward', RuntimeError),
    ('test_n_nagative', rand_x(4), (-1, -1), (1, 2), 'backward', ValueError),
    ('test_n_not_sequence', rand_x(4), -1, None, 'backward', ValueError),
    ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError),
    ('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', ValueError),
    ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', ValueError)
])
class TestRfftnException(unittest.TestCase):

    def test_rfftn(self):
        """Test that paddle.fft.rfftn rejects invalid arguments.

        Test cases include:
        - input type error (complex input)
        - n out of range (negative or zero)
        - n not a sequence
        - axis out of range
        - norm not in the enum values
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.rfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
              [('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
               ('test_n_grater_than_input_length', rand_x(
                   5, max_dim_len=5), 11, -1, 'backward'),
               ('test_n_smaller_than_input_length', rand_x(
                   5, min_dim_len=5), 3, -1, 'backward'),
               ('test_axis_not_last', rand_x(5), None, 3, 'backward'),
               ('test_norm_forward', rand_x(5), None, 3, 'forward'),
               ('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestIhfft(unittest.TestCase):

    def test_ihfft(self):
        """Compare paddle.fft.ihfft against scipy.fft.ihfft.

        Covers padding / truncation through ``n``, a non-last ``axis`` and
        the ``forward`` / ``ortho`` norm modes.
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ihfft(self.x, self.n, self.axis, self.norm),
                paddle.fft.ihfft(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    [('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
     ('test_n_zero', rand_x(2), 0, -1, 'backward', ValueError),
     ('test_axis_out_of_range', rand_x(1), None, 10, 'backward', ValueError),
     ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
     ('test_norm_not_in_enum_value', rand_x(2), None, -1, 'random', ValueError)]
)
class TestIhfftException(unittest.TestCase):

    def test_ihfft(self):
        """Test that paddle.fft.ihfft rejects invalid arguments.

        Test cases include:
        - n out of range (negative or zero)
        - axis out of range
        - axis type error (a sequence instead of an int)
        - norm not in the enum values
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.ihfft(paddle.to_tensor(self.x), self.n, self.axis,
                                 self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
    ('test_x_float64', rand_x(5), None, (0, 1), 'backward'),
    ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (11, 11),
     (0, 1), 'backward'),
    ('test_n_smaller_than_input_length', rand_x(5, min_dim_len=5), (1, 1),
     (0, 1), 'backward'),
    ('test_axis_random', rand_x(5), None, (1, 2), 'backward'),
    ('test_axis_none', rand_x(5), None, None, 'backward'),
    ('test_norm_forward', rand_x(5), None, (0, 1), 'forward'),
    ('test_norm_ortho', rand_x(5), None, (0, 1), 'ortho'),
])
class TestIhfft2(unittest.TestCase):

    def test_ihfft2(self):
        """Compare paddle.fft.ihfft2 against scipy.fft.ihfft2.

        Covers padding / truncation through ``n``, non-default and ``None``
        axes, and the ``forward`` / ``ortho`` norm modes.
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ihfft2(self.x, self.n, self.axis, self.norm),
                paddle.fft.ihfft2(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'),
    # NOTE(review): the first two cases pass norm=None; confirm the
    # ValueError comes from the input-type / input-dim check and not from
    # norm validation, otherwise those cases do not test what they claim.
    [('test_x_complex_input', rand_x(2, complex=True), None,
      (0, 1), None, ValueError),
     ('test_x_1dim_tensor', rand_x(1), None, (0, 1), None, ValueError),
     ('test_n_nagative', rand_x(2), -1, (0, 1), 'backward', ValueError),
     ('test_n_len_not_equal_axis', rand_x(5, max_dim_len=5), 11,
      (0, 1), 'backward', ValueError),
     ('test_n_zero', rand_x(2), (0, 0), (0, 1), 'backward', ValueError),
     ('test_axis_out_of_range', rand_x(2), None,
      (0, 1, 2), 'backward', ValueError),
     ('test_axis_with_array', rand_x(1), None, (0, 1), 'backward', ValueError),
     ('test_axis_not_sequence', rand_x(5), None, -10, 'backward', ValueError),
     ('test_norm_not_enum', rand_x(2), None, -1, 'random', ValueError)])
class TestIhfft2Exception(unittest.TestCase):

    def test_ihfft2(self):
        """Test that paddle.fft.ihfft2 rejects invalid arguments.

        Test cases include:
        - input type error (complex input)
        - input dim error (1-D input)
        - n out of range (negative or zero)
        - n whose length differs from axis
        - axis out of range
        - axis not a sequence
        - norm not in the enum values
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.ihfft2(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm)


@place(DEVICES)
@parameterize(
    (TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
    [('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
     ('test_n_grater_input_length', rand_x(5, max_dim_len=5), (11, 11),
      (0, 1), 'backward'),
     ('test_n_smaller_input_length', rand_x(5, min_dim_len=5), (1, 1),
      (0, 1), 'backward'),
     ('test_axis_not_default', rand_x(5), None, (1, 2), 'backward'),
     ('test_norm_forward', rand_x(5), None, None, 'forward'),
     ('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestIhfftn(unittest.TestCase):

    def test_ihfftn(self):
        """Compare paddle.fft.ihfftn against scipy.fft.ihfftn.

        Covers padding / truncation through ``n``, non-default axes and the
        ``forward`` / ``ortho`` norm modes.
        """
        with paddle.fluid.dygraph.guard(self.place):
            # assert_allclose rejects mismatched shapes (np.allclose would
            # broadcast them) and reports the offending elements on failure.
            np.testing.assert_allclose(
                scipy.fft.ihfftn(self.x, self.n, self.axis, self.norm),
                paddle.fft.ihfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
    ('test_x_complex', rand_x(
        4, complex=True), None, None, 'backward', RuntimeError),
    ('test_n_nagative', rand_x(4), -1, None, 'backward', ValueError),
    ('test_n_zero', rand_x(4), 0, None, 'backward', ValueError),
    ('test_axis_out_of_range', rand_x(1), None, [0, 1], 'backward', ValueError),
    ('test_norm_not_in_enum', rand_x(2), None, -1, 'random', ValueError)
])
class TestIhfftnException(unittest.TestCase):

    def test_ihfftn(self):
        """Test that paddle.fft.ihfftn rejects invalid arguments.

        Test cases include:
        - input type error (complex input)
        - n out of range (negative or zero)
        - axis out of range
        - norm not in the enum values
        """
        with paddle.fluid.dygraph.guard(self.place):
            with self.assertRaises(self.expect_exception):
                paddle.fft.ihfftn(paddle.to_tensor(self.x), self.n, self.axis,
                                  self.norm)


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'n', 'd', 'dtype'), [
    ('test_without_d', 20, 1, 'float32'),
    ('test_with_d', 20, 0.5, 'float32'),
])
class TestFftFreq(unittest.TestCase):

    def test_fftfreq(self):
        """Compare paddle.fft.fftfreq against scipy.fft.fftfreq.

        Uses window length ``n``, sample spacing ``d`` and output ``dtype``.
        The scipy reference is cast to ``dtype`` before comparison.
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fftfreq(self.n, self.d).astype(self.dtype),
                paddle.fft.fftfreq(self.n, self.d, self.dtype).numpy(),
                rtol=RTOL.get(str(self.dtype)),
                atol=ATOL.get(str(self.dtype)))


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'n', 'd', 'dtype'), [
    ('test_without_d', 20, 1, 'float32'),
    ('test_with_d', 20, 0.5, 'float32'),
])
class TestRfftFreq(unittest.TestCase):

    def test_rfftfreq(self):
        """Compare paddle.fft.rfftfreq against scipy.fft.rfftfreq.

        Uses window length ``n``, sample spacing ``d`` and output ``dtype``.
        The scipy reference is cast to ``dtype`` before comparison.
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.rfftfreq(self.n, self.d).astype(self.dtype),
                paddle.fft.rfftfreq(self.n, self.d, self.dtype).numpy(),
                rtol=RTOL.get(str(self.dtype)),
                atol=ATOL.get(str(self.dtype)))


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
    ('test_1d', np.random.randn(10), (0, ), 'float64'),
    ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
    ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
    ('test_2d_odd_with_all_axes',
     np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
])
class TestFftShift(unittest.TestCase):

    def test_fftshift(self):
        """Compare paddle.fft.fftshift against scipy.fft.fftshift.

        Covers 1-D and 2-D inputs, explicit ``axes`` and ``axes=None``, and
        an odd-length complex input. Tolerances are selected from
        ``x.dtype``; the ``dtype`` column is not read by this test.
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.fftshift(self.x, self.axes),
                paddle.fft.fftshift(paddle.to_tensor(self.x),
                                    self.axes).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


@place(DEVICES)
# Every case row carries four values, so the field tuple names all four
# (matching TestFftShift) instead of silently leaving 'dtype' unnamed.
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
    ('test_1d', np.random.randn(10), (0, ), 'float64'),
    ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
    ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
    ('test_2d_odd_with_all_axes',
     np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
])
class TestIfftShift(unittest.TestCase):

    def test_ifftshift(self):
        """Compare paddle.fft.ifftshift against scipy.fft.ifftshift.

        Covers 1-D and 2-D inputs, explicit ``axes`` and ``axes=None``, and
        an odd-length complex input. Tolerances are selected from
        ``x.dtype``; the ``dtype`` column is not read by this test.
        """
        with paddle.fluid.dygraph.guard(self.place):
            np.testing.assert_allclose(
                scipy.fft.ifftshift(self.x, self.axes),
                paddle.fft.ifftshift(paddle.to_tensor(self.x),
                                     self.axes).numpy(),
                rtol=RTOL.get(str(self.x.dtype)),
                atol=ATOL.get(str(self.x.dtype)))


if __name__ == '__main__':
    # Script entry point: hand control to the unittest runner, which
    # collects every TestCase class defined in this module.
    unittest.main()

# yapf: enable