未验证 提交 fa7aa6b8 编写于 作者: F Feiyu Chan 提交者: GitHub

1. fix ifftshift(missing negative sign before shifts); (#36835)

2. add complex data type support for paddle.shape at graph assembly.
上级 c716cf35
......@@ -1345,7 +1345,7 @@ def ifftshift(x, axes=None, name=None):
# shift all axes
rank = len(x.shape)
axes = list(range(0, rank))
shifts = shape // 2
shifts = -shape // 2
elif isinstance(axes, int):
shifts = -shape[axes] // 2
else:
......
......@@ -11396,9 +11396,10 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)]
"""
check_variable_and_dtype(
input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape')
check_variable_and_dtype(input, 'input', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'complex128'
], 'shape')
helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
......
......@@ -1009,11 +1009,13 @@ class TestRfftFreq(unittest.TestCase):
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'dtype'),
[('test_1d', np.random.randn(10), (0, ), 'float64'),
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')])
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
])
class TestFftShift(unittest.TestCase):
def test_fftshift(self):
"""Test fftshift with norm condition
......@@ -1028,11 +1030,13 @@ class TestFftShift(unittest.TestCase):
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
@parameterize(
(TEST_CASE_NAME, 'x', 'axes'),
[('test_1d', np.random.randn(10), (0, ),
'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
])
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')])
class TestIfftShift(unittest.TestCase):
def test_ifftshift(self):
"""Test ifftshift with norm condition
......
......@@ -888,6 +888,56 @@ class TestIhfftnException(unittest.TestCase):
pass
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
])
class TestFftShift(unittest.TestCase):
def test_fftshift(self):
"""Test fftshift with norm condition
"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = paddle.fft.fftshift(input, axes)
exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'axes'),
[('test_1d', np.random.randn(10), (0, ),
'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')])
class TestIfftShift(unittest.TestCase):
def test_ifftshift(self):
"""Test ifftshift with norm condition
"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = paddle.fft.ifftshift(input, axes)
exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册