未验证 提交 29ebf621 编写于 作者: F Feiyu Chan 提交者: GitHub

fix typos: is_integer in attribute.py (#37749)

* fix typos: is_integer in attribute.py
* add more test cases for fft
上级 e8c6c7df
......@@ -15,7 +15,7 @@
from typing import Sequence
import numpy as np
import paddle
from .tensor.attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype
from .tensor.attribute import is_complex, is_floating_point, is_integer, _real_to_complex_dtype, _complex_to_real_dtype
from .fluid.framework import in_dygraph_mode
from . import _C_ops
from .fluid.data_feeder import check_variable_and_dtype
......@@ -196,7 +196,7 @@ def fft(x, n=None, axis=-1, norm="backward", name=None):
"""
if is_interger(x) or is_floating_point(x):
if is_integer(x) or is_floating_point(x):
return fft_r2c(
x, n, axis, norm, forward=True, onesided=False, name=name)
else:
......@@ -260,7 +260,7 @@ def ifft(x, n=None, axis=-1, norm="backward", name=None):
# 0.14285714+6.25898038e-01j]
"""
if is_interger(x) or is_floating_point(x):
if is_integer(x) or is_floating_point(x):
return fft_r2c(
x, n, axis, norm, forward=False, onesided=False, name=name)
else:
......@@ -521,7 +521,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None):
# [-8.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]]
"""
if is_interger(x) or is_floating_point(x):
if is_integer(x) or is_floating_point(x):
return fftn_r2c(
x, s, axes, norm, forward=True, onesided=False, name=name)
else:
......@@ -585,7 +585,7 @@ def ifftn(x, s=None, axes=None, norm="backward", name=None):
# [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]]
"""
if is_interger(x) or is_floating_point(x):
if is_integer(x) or is_floating_point(x):
return fftn_r2c(
x, s, axes, norm, forward=False, onesided=False, name=name)
else:
......@@ -1355,7 +1355,7 @@ def ifftshift(x, axes=None, name=None):
# internal functions
def fft_c2c(x, n, axis, norm, forward, name):
if is_interger(x):
if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
......@@ -1388,7 +1388,7 @@ def fft_c2c(x, n, axis, norm, forward, name):
def fft_r2c(x, n, axis, norm, forward, onesided, name):
if is_interger(x):
if is_integer(x):
x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm)
axis = axis if axis is not None else -1
......@@ -1425,7 +1425,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
def fft_c2r(x, n, axis, norm, forward, name):
if is_interger(x):
if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
......@@ -1464,7 +1464,7 @@ def fft_c2r(x, n, axis, norm, forward, name):
def fftn_c2c(x, s, axes, norm, forward, name):
if is_interger(x):
if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
......@@ -1512,7 +1512,7 @@ def fftn_c2c(x, s, axes, norm, forward, name):
def fftn_r2c(x, s, axes, norm, forward, onesided, name):
if is_interger(x):
if is_integer(x):
x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm)
if s is not None:
......@@ -1567,7 +1567,7 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
def fftn_c2r(x, s, axes, norm, forward, name):
if is_interger(x):
if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
......
......@@ -120,6 +120,33 @@ class TestFft(unittest.TestCase):
atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
('test_x_complex', rand_x(
5, complex=True), None, -1,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), 11, -1,
'backward'), ('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5, complex=True), 3, -1, 'backward'),
('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestIfft(unittest.TestCase):
def test_fft(self):
"""Test ifft with norm condition
"""
with paddle.fluid.dygraph.guard(self.place):
self.assertTrue(
np.allclose(
scipy.fft.ifft(self.x, self.n, self.axis, self.norm),
paddle.fft.ifft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
......@@ -230,6 +257,32 @@ class TestFftn(unittest.TestCase):
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_x_complex128', rand_x(
5, complex=True), None, None,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (1, 2), 'backward'), (
'test_n_smaller_input_length', rand_x(
5, min_dim_len=5, complex=True), (3, 3), (1, 2), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestIFftn(unittest.TestCase):
def test_ifftn(self):
"""Test ifftn with norm condition
"""
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.ifftn(self.x, self.n, self.axis, self.norm),
paddle.fft.ifftn(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128',
......
......@@ -60,7 +60,7 @@ def is_floating_point(x):
return is_fp_dtype
def is_interger(x):
def is_integer(x):
dtype = x.dtype
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
dtype == core.VarDesc.VarType.INT8 or
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册