From 34b6860ea36f90f0440a620be1de80c8d154d604 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Wed, 27 Oct 2021 11:39:46 +0800 Subject: [PATCH] fix fftshift/ifftshift on static mode (#36748) * fix fftshift/ifftshift on static mode * update roll_op version * add more test cases for fftshift/ifftshift --- paddle/fluid/operators/roll_op.cc | 13 +++++++++---- python/paddle/fft.py | 16 ++++++++-------- .../paddle/fluid/tests/unittests/fft/test_fft.py | 10 ++++++---- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b74dfc984af..f82510556fd 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -183,7 +183,12 @@ REGISTER_OP_VERSION(roll) "(std::vector) Axis along which to roll. " "It must have the same size with shifts, or size = 0.", std::vector()) - .DeleteAttr( - "dims", - "(std::vector) Dims along which to roll. " - "It must have the same size with shifts, or size = 0.")); + .DeleteAttr("dims", + "(std::vector) Dims along which to roll. " + "It must have the same size with shifts, or size = 0.")) + .AddCheckpoint( + R"ROC(Upgrade roll add a dispensable input "ShiftsTensor".)ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "ShiftsTensor", + "The number of places by which the elements of" + "the tensor are shifted.")); diff --git a/python/paddle/fft.py b/python/paddle/fft.py index de15eba0fef..7399ccc1ace 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1300,13 +1300,13 @@ def fftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts = shape[axes] // 2 else: - shifts = [shape[ax] // 2 for ax in axes] + shifts = paddle.concat([shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) @@ -1343,13 +1343,13 @@ def ifftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [-size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts = -shape[axes] // 2 else: - shifts = [-shape[ax] // 2 for ax in axes] + shifts = paddle.concat([-shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index c83c943217d..604de11521b 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1009,10 +1009,11 @@ class TestRfftFreq(unittest.TestCase): @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ - ('test_1d', np.random.randn(10), (0, ), 'float64'), - ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), -]) +@parameterize( + (TEST_CASE_NAME, 'x', 'axes', 'dtype'), + [('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')]) class TestFftShift(unittest.TestCase): def test_fftshift(self): """Test fftshift with norm condition @@ -1030,6 +1031,7 @@ class TestFftShift(unittest.TestCase): @parameterize((TEST_CASE_NAME, 'x', 'axes'), [ ('test_1d', np.random.randn(10), (0, ), 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), ]) class TestIfftShift(unittest.TestCase): def test_ifftshift(self): -- GitLab