diff --git a/paddle/fluid/operators/roll_op.cc b/paddle/fluid/operators/roll_op.cc index b74dfc984affb2b003bdfce84eb8493738887308..f82510556fde87fbf4aeb1904e29325358598791 100644 --- a/paddle/fluid/operators/roll_op.cc +++ b/paddle/fluid/operators/roll_op.cc @@ -183,7 +183,12 @@ REGISTER_OP_VERSION(roll) "(std::vector) Axis along which to roll. " "It must have the same size with shifts, or size = 0.", std::vector()) - .DeleteAttr( - "dims", - "(std::vector) Dims along which to roll. " - "It must have the same size with shifts, or size = 0.")); + .DeleteAttr("dims", + "(std::vector) Dims along which to roll. " + "It must have the same size with shifts, or size = 0.")) + .AddCheckpoint( + R"ROC(Upgrade roll add a dispensable input "ShiftsTensor".)ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "ShiftsTensor", + "The number of places by which the elements of" + "the tensor are shifted.")); diff --git a/python/paddle/fft.py b/python/paddle/fft.py index de15eba0feffaa89004359f12703d7f142f34ff5..7399ccc1ace59527c9067039a986ccf18562c635 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1300,13 +1300,13 @@ def fftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts = shape[axes] // 2 else: - shifts = [shape[ax] // 2 for ax in axes] + shifts = paddle.concat([shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) @@ -1343,13 +1343,13 @@ def ifftshift(x, axes=None, name=None): shape = paddle.shape(x) if axes is None: # shift all axes - rank = paddle.rank(x).reshape([1]) - axes = axes or paddle.arange(0, rank) - shifts = [-size // 2 for size in shape] + rank = len(x.shape) + axes = list(range(0, rank)) + shifts = shape // 2 elif isinstance(axes, int): shifts = -shape[axes] // 2 else: - shifts = [-shape[ax] // 2 for ax in axes] + shifts = paddle.concat([-shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index c83c943217d4e6cba264a63154bfa97091bc66a5..604de11521b7d6a923335637cd2e59aa9f3c4cb7 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1009,10 +1009,11 @@ class TestRfftFreq(unittest.TestCase): @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [ - ('test_1d', np.random.randn(10), (0, ), 'float64'), - ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), -]) +@parameterize( + (TEST_CASE_NAME, 'x', 'axes', 'dtype'), + [('test_1d', np.random.randn(10), (0, ), 'float64'), + ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')]) class TestFftShift(unittest.TestCase): def test_fftshift(self): """Test fftshift with norm condition @@ -1030,6 +1031,7 @@ class TestFftShift(unittest.TestCase): @parameterize((TEST_CASE_NAME, 'x', 'axes'), [ ('test_1d', np.random.randn(10), (0, ), 'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'), + ('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'), ]) class TestIfftShift(unittest.TestCase): def test_ifftshift(self):