未验证 提交 29ebf621 编写于 作者: F Feiyu Chan 提交者: GitHub

fix typos: is_integer in attribute.py (#37749)

* fix typos: is_integer in attribute.py
* add more test cases for fft
上级 e8c6c7df
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from typing import Sequence from typing import Sequence
import numpy as np import numpy as np
import paddle import paddle
from .tensor.attribute import is_complex, is_floating_point, is_interger, _real_to_complex_dtype, _complex_to_real_dtype from .tensor.attribute import is_complex, is_floating_point, is_integer, _real_to_complex_dtype, _complex_to_real_dtype
from .fluid.framework import in_dygraph_mode from .fluid.framework import in_dygraph_mode
from . import _C_ops from . import _C_ops
from .fluid.data_feeder import check_variable_and_dtype from .fluid.data_feeder import check_variable_and_dtype
...@@ -196,7 +196,7 @@ def fft(x, n=None, axis=-1, norm="backward", name=None): ...@@ -196,7 +196,7 @@ def fft(x, n=None, axis=-1, norm="backward", name=None):
""" """
if is_interger(x) or is_floating_point(x): if is_integer(x) or is_floating_point(x):
return fft_r2c( return fft_r2c(
x, n, axis, norm, forward=True, onesided=False, name=name) x, n, axis, norm, forward=True, onesided=False, name=name)
else: else:
...@@ -260,7 +260,7 @@ def ifft(x, n=None, axis=-1, norm="backward", name=None): ...@@ -260,7 +260,7 @@ def ifft(x, n=None, axis=-1, norm="backward", name=None):
# 0.14285714+6.25898038e-01j] # 0.14285714+6.25898038e-01j]
""" """
if is_interger(x) or is_floating_point(x): if is_integer(x) or is_floating_point(x):
return fft_r2c( return fft_r2c(
x, n, axis, norm, forward=False, onesided=False, name=name) x, n, axis, norm, forward=False, onesided=False, name=name)
else: else:
...@@ -521,7 +521,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None): ...@@ -521,7 +521,7 @@ def fftn(x, s=None, axes=None, norm="backward", name=None):
# [-8.+0.j 0.+0.j 0.+0.j 0.-0.j] # [-8.+0.j 0.+0.j 0.+0.j 0.-0.j]
# [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]] # [-8.-8.j 0.+0.j 0.+0.j 0.-0.j]]]
""" """
if is_interger(x) or is_floating_point(x): if is_integer(x) or is_floating_point(x):
return fftn_r2c( return fftn_r2c(
x, s, axes, norm, forward=True, onesided=False, name=name) x, s, axes, norm, forward=True, onesided=False, name=name)
else: else:
...@@ -585,7 +585,7 @@ def ifftn(x, s=None, axes=None, norm="backward", name=None): ...@@ -585,7 +585,7 @@ def ifftn(x, s=None, axes=None, norm="backward", name=None):
# [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]] # [ 0.33333333+0.j -0.16666667-0.28867513j -0.16666667+0.28867513j]]
""" """
if is_interger(x) or is_floating_point(x): if is_integer(x) or is_floating_point(x):
return fftn_r2c( return fftn_r2c(
x, s, axes, norm, forward=False, onesided=False, name=name) x, s, axes, norm, forward=False, onesided=False, name=name)
else: else:
...@@ -1355,7 +1355,7 @@ def ifftshift(x, axes=None, name=None): ...@@ -1355,7 +1355,7 @@ def ifftshift(x, axes=None, name=None):
# internal functions # internal functions
def fft_c2c(x, n, axis, norm, forward, name): def fft_c2c(x, n, axis, norm, forward, name):
if is_interger(x): if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x): elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
...@@ -1388,7 +1388,7 @@ def fft_c2c(x, n, axis, norm, forward, name): ...@@ -1388,7 +1388,7 @@ def fft_c2c(x, n, axis, norm, forward, name):
def fft_r2c(x, n, axis, norm, forward, onesided, name): def fft_r2c(x, n, axis, norm, forward, onesided, name):
if is_interger(x): if is_integer(x):
x = paddle.cast(x, paddle.get_default_dtype()) x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm) _check_normalization(norm)
axis = axis if axis is not None else -1 axis = axis if axis is not None else -1
...@@ -1425,7 +1425,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name): ...@@ -1425,7 +1425,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
def fft_c2r(x, n, axis, norm, forward, name): def fft_c2r(x, n, axis, norm, forward, name):
if is_interger(x): if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x): elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
...@@ -1464,7 +1464,7 @@ def fft_c2r(x, n, axis, norm, forward, name): ...@@ -1464,7 +1464,7 @@ def fft_c2r(x, n, axis, norm, forward, name):
def fftn_c2c(x, s, axes, norm, forward, name): def fftn_c2c(x, s, axes, norm, forward, name):
if is_interger(x): if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x): elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
...@@ -1512,7 +1512,7 @@ def fftn_c2c(x, s, axes, norm, forward, name): ...@@ -1512,7 +1512,7 @@ def fftn_c2c(x, s, axes, norm, forward, name):
def fftn_r2c(x, s, axes, norm, forward, onesided, name): def fftn_r2c(x, s, axes, norm, forward, onesided, name):
if is_interger(x): if is_integer(x):
x = paddle.cast(x, paddle.get_default_dtype()) x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm) _check_normalization(norm)
if s is not None: if s is not None:
...@@ -1567,7 +1567,7 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name): ...@@ -1567,7 +1567,7 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
def fftn_c2r(x, s, axes, norm, forward, name): def fftn_c2r(x, s, axes, norm, forward, name):
if is_interger(x): if is_integer(x):
x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype())) x = paddle.cast(x, _real_to_complex_dtype(paddle.get_default_dtype()))
elif is_floating_point(x): elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
......
...@@ -120,6 +120,33 @@ class TestFft(unittest.TestCase): ...@@ -120,6 +120,33 @@ class TestFft(unittest.TestCase):
atol=ATOL.get(str(self.x.dtype)))) atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, -1, 'backward'),
('test_x_complex', rand_x(
5, complex=True), None, -1,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), 11, -1,
'backward'), ('test_n_smaller_than_input_length', rand_x(
5, min_dim_len=5, complex=True), 3, -1, 'backward'),
('test_axis_not_last', rand_x(5), None, 3, 'backward'),
('test_norm_forward', rand_x(5), None, 3, 'forward'),
('test_norm_ortho', rand_x(5), None, 3, 'ortho')])
class TestIfft(unittest.TestCase):
def test_fft(self):
"""Test ifft with norm condition
"""
with paddle.fluid.dygraph.guard(self.place):
self.assertTrue(
np.allclose(
scipy.fft.ifft(self.x, self.n, self.axis, self.norm),
paddle.fft.ifft(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype))))
@place(DEVICES) @place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [ @parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm', 'expect_exception'), [
('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError), ('test_n_nagative', rand_x(2), -1, -1, 'backward', ValueError),
...@@ -230,6 +257,32 @@ class TestFftn(unittest.TestCase): ...@@ -230,6 +257,32 @@ class TestFftn(unittest.TestCase):
atol=ATOL.get(str(self.x.dtype))) atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'),
[('test_x_float64', rand_x(5, np.float64), None, None, 'backward'),
('test_x_complex128', rand_x(
5, complex=True), None, None,
'backward'), ('test_n_grater_input_length', rand_x(
5, max_dim_len=5), (6, 6), (1, 2), 'backward'), (
'test_n_smaller_input_length', rand_x(
5, min_dim_len=5, complex=True), (3, 3), (1, 2), 'backward'),
('test_axis_not_default', rand_x(5), None, (1, 2),
'backward'), ('test_norm_forward', rand_x(5), None, None, 'forward'),
('test_norm_ortho', rand_x(5), None, None, 'ortho')])
class TestIFftn(unittest.TestCase):
def test_ifftn(self):
"""Test ifftn with norm condition
"""
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
scipy.fft.ifftn(self.x, self.n, self.axis, self.norm),
paddle.fft.ifftn(
paddle.to_tensor(self.x), self.n, self.axis, self.norm),
rtol=RTOL.get(str(self.x.dtype)),
atol=ATOL.get(str(self.x.dtype)))
@place(DEVICES) @place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [ @parameterize((TEST_CASE_NAME, 'x', 'n', 'axis', 'norm'), [
('test_x_complex128', ('test_x_complex128',
......
...@@ -60,7 +60,7 @@ def is_floating_point(x): ...@@ -60,7 +60,7 @@ def is_floating_point(x):
return is_fp_dtype return is_fp_dtype
def is_interger(x): def is_integer(x):
dtype = x.dtype dtype = x.dtype
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
dtype == core.VarDesc.VarType.INT8 or dtype == core.VarDesc.VarType.INT8 or
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册