From f3ee5c999bcfaa0f10554df8090638bf0812edf3 Mon Sep 17 00:00:00 2001
From: Feiyu Chan
Date: Fri, 29 Oct 2021 12:17:53 +0800
Subject: [PATCH] 1. fix ifftshift(missing negative sign before shifts);
 (#36834) 2. add complex data type support for paddle.shape at graph
 assembly.

---
 python/paddle/fft.py                          |  2 +-
 python/paddle/fluid/layers/nn.py              |  7 +--
 .../fluid/tests/unittests/fft/test_fft.py     | 24 +++++----
 .../fft/test_fft_with_static_graph.py         | 50 +++++++++++++++++++
 4 files changed, 69 insertions(+), 14 deletions(-)

diff --git a/python/paddle/fft.py b/python/paddle/fft.py
index 7399ccc1ac..a62e502203 100644
--- a/python/paddle/fft.py
+++ b/python/paddle/fft.py
@@ -1345,7 +1345,7 @@ def ifftshift(x, axes=None, name=None):
         # shift all axes
         rank = len(x.shape)
         axes = list(range(0, rank))
-        shifts = shape // 2
+        shifts = -shape // 2
     elif isinstance(axes, int):
         shifts = -shape[axes] // 2
     else:
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index ceda304b26..dd0abd212e 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -11396,9 +11396,10 @@ def shape(input):
             res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
             print(res) # [array([ 3, 100, 100], dtype=int32)]
     """
-    check_variable_and_dtype(
-        input, 'input',
-        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape')
+    check_variable_and_dtype(input, 'input', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
+        'complex128'
+    ], 'shape')
     helper = LayerHelper('shape', **locals())
     out = helper.create_variable_for_type_inference(dtype='int32')
     helper.append_op(
diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py
index 604de11521..0ef7a1e939 100644
--- a/python/paddle/fluid/tests/unittests/fft/test_fft.py
+++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py
@@ -1009,11 +1009,13 @@ class TestRfftFreq(unittest.TestCase):
 
 
 @place(DEVICES)
-@parameterize(
-    (TEST_CASE_NAME, 'x', 'axes', 'dtype'),
-    [('test_1d', np.random.randn(10), (0, ), 'float64'),
-     ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
-     ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')])
+@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
+    ('test_1d', np.random.randn(10), (0, ), 'float64'),
+    ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
+    ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
+    ('test_2d_odd_with_all_axes',
+     np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
+])
 class TestFftShift(unittest.TestCase):
     def test_fftshift(self):
         """Test fftshift with norm condition
@@ -1028,11 +1030,13 @@ class TestFftShift(unittest.TestCase):
 
 
 @place(DEVICES)
-@parameterize((TEST_CASE_NAME, 'x', 'axes'), [
-    ('test_1d', np.random.randn(10), (0, ), 'float64'),
-    ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
-    ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
-])
+@parameterize(
+    (TEST_CASE_NAME, 'x', 'axes'),
+    [('test_1d', np.random.randn(10), (0, ),
+      'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
+     ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
+     ('test_2d_odd_with_all_axes',
+      np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')])
 class TestIfftShift(unittest.TestCase):
     def test_ifftshift(self):
         """Test ifftshift with norm condition
diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py
index ac9d1557b5..4f19cd06a4 100644
--- a/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py
+++ b/python/paddle/fluid/tests/unittests/fft/test_fft_with_static_graph.py
@@ -888,6 +888,56 @@ class TestIhfftnException(unittest.TestCase):
         pass
 
 
+@place(DEVICES)
+@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
+    ('test_1d', np.random.randn(10), (0, ), 'float64'),
+    ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
+    ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
+    ('test_2d_odd_with_all_axes',
+     np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
+])
+class TestFftShift(unittest.TestCase):
+    def test_fftshift(self):
+        """Test fftshift with norm condition
+        """
+        paddle.enable_static()
+        mp, sp = paddle.static.Program(), paddle.static.Program()
+        with paddle.static.program_guard(mp, sp):
+            input = paddle.static.data('input', x.shape, dtype=x.dtype)
+            output = paddle.fft.fftshift(input, axes)
+
+        exe = paddle.static.Executor(place)
+        exe.run(sp)
+        [output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
+        yield output
+        paddle.disable_static()
+
+
+@place(DEVICES)
+@parameterize(
+    (TEST_CASE_NAME, 'x', 'axes'),
+    [('test_1d', np.random.randn(10), (0, ),
+      'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
+     ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
+     ('test_2d_odd_with_all_axes',
+      np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')])
+class TestIfftShift(unittest.TestCase):
+    def test_ifftshift(self):
+        """Test ifftshift with norm condition
+        """
+        paddle.enable_static()
+        mp, sp = paddle.static.Program(), paddle.static.Program()
+        with paddle.static.program_guard(mp, sp):
+            input = paddle.static.data('input', x.shape, dtype=x.dtype)
+            output = paddle.fft.ifftshift(input, axes)
+
+        exe = paddle.static.Executor(place)
+        exe.run(sp)
+        [output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
+        yield output
+        paddle.disable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
 
--
GitLab
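Note (not part of the patch): the fft.py hunk rests on ifftshift rolling each axis by the
negation of the shift fftshift applied; the two shifts only coincide for even-length axes,
so the missing sign surfaces on odd lengths such as the new 5x5 test cases. A minimal
sanity check of that argument, using NumPy only as a reference implementation (numpy.roll
and numpy.fft are assumptions of this sketch and are not touched by the patch):

    import numpy as np

    x = np.arange(5)                                    # odd length, where the bug shows up
    shifted = np.roll(x, x.shape[0] // 2)               # what fftshift does: roll by n // 2
    bad = np.roll(shifted, shifted.shape[0] // 2)       # old code: same positive shift again
    good = np.roll(shifted, -(shifted.shape[0] // 2))   # fixed code: negated shift

    print(bad)   # [1 2 3 4 0]  -> does not recover x
    print(good)  # [0 1 2 3 4]  -> recovers x
    np.testing.assert_array_equal(good, x)
    np.testing.assert_array_equal(np.fft.ifftshift(np.fft.fftshift(x)), x)

The nn.py hunk is independent of this: it only adds 'complex64'/'complex128' to the input
dtypes accepted by paddle.shape so that a static graph can be assembled over complex
tensors, which is what the new static-graph tests exercise; the op's int32 output dtype
is unchanged.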