diff --git a/python/paddle/fft.py b/python/paddle/fft.py index 7399ccc1ace59527c9067039a986ccf18562c635..a62e502203b631d22e81558a4f9c32bb9b3cfbae 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1345,7 +1345,7 @@ def ifftshift(x, axes=None, name=None): # shift all axes rank = len(x.shape) axes = list(range(0, rank)) - shifts = shape // 2 + shifts = -shape // 2 elif isinstance(axes, int): shifts = -shape[axes] // 2 else: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index ceda304b26e89583a008c605bc6fdf1216903158..dd0abd212e8342b7a4f9d15f86b529368a401870 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11396,9 +11396,10 @@ def shape(input): res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) print(res) # [array([ 3, 100, 100], dtype=int32)] """ - check_variable_and_dtype( - input, 'input', - ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape') + check_variable_and_dtype(input, 'input', [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64', + 'complex128' + ], 'shape') helper = LayerHelper('shape', **locals()) out = helper.create_variable_for_type_inference(dtype='int32') helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index 604de11521b7d6a923335637cd2e59aa9f3c4cb7..0ef7a1e939e0220a69d1c46ef7a530d9ffa54adc 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1009,11 +1009,13 @@ class TestRfftFreq(unittest.TestCase): @place(DEVICES) -@parameterize( - (TEST_CASE_NAME, 'x', 'axes', 'dtype'), - [('test_1d', np.random.randn(10), (0, ), 'float64'), - ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), - ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')]) +@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ + ('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), + ('test_2d_odd_with_all_axes', + np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'), +]) class TestFftShift(unittest.TestCase): def test_fftshift(self): """Test fftshift with norm condition @@ -1028,11 +1030,13 @@ class TestFftShift(unittest.TestCase): @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'x', 'axes'), [ - ('test_1d', np.random.randn(10), (0, ), 'float64'), - ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), - ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), -]) +@parameterize( + (TEST_CASE_NAME, 'x', 'axes'), + [('test_1d', np.random.randn(10), (0, ), + 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), + ('test_2d_odd_with_all_axes', + np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')]) class TestIfftShift(unittest.TestCase): def test_ifftshift(self): """Test ifftshift with norm condition diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py index ac9d1557b53e9da9f644901f122f6660b05255b7..4f19cd06a493fc71935ea2d1cdb23f9d80c8ab46 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py @@ -888,6 +888,56 @@ class TestIhfftnException(unittest.TestCase): pass +@place(DEVICES) +@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ + ('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), + ('test_2d_odd_with_all_axes', + np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'), +]) +class TestFftShift(unittest.TestCase): + def test_fftshift(self): + """Test fftshift with norm condition + """ + paddle.enable_static() + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + input = paddle.static.data('input', x.shape, dtype=x.dtype) + output = paddle.fft.fftshift(input, axes) + + exe = paddle.static.Executor(place) + exe.run(sp) + [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) + yield output + paddle.disable_static() + + +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'x', 'axes'), + [('test_1d', np.random.randn(10), (0, ), + 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), + ('test_2d_odd_with_all_axes', + np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')]) +class TestIfftShift(unittest.TestCase): + def test_ifftshift(self): + """Test ifftshift with norm condition + """ + paddle.enable_static() + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + input = paddle.static.data('input', x.shape, dtype=x.dtype) + output = paddle.fft.ifftshift(input, axes) + + exe = paddle.static.Executor(place) + exe.run(sp) + [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) + yield output + paddle.disable_static() + + if __name__ == '__main__': unittest.main()