未验证 提交 2bf82e75 编写于 作者: F Feiyu Chan 提交者: GitHub

fix fft axis (#36321)

fix: `-1` is used when fft's axis is `0`
上级 ea76457c
......@@ -1340,7 +1340,7 @@ def fft_c2c(x, n, axis, norm, forward, name):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm)
axis = axis or -1
axis = axis if axis is not None else -1
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
......@@ -1370,7 +1370,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
if is_interger(x):
x = paddle.cast(x, paddle.get_default_dtype())
_check_normalization(norm)
axis = axis or -1
axis = axis if axis is not None else -1
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
......@@ -1409,7 +1409,7 @@ def fft_c2r(x, n, axis, norm, forward, name):
elif is_floating_point(x):
x = paddle.cast(x, _real_to_complex_dtype(x.dtype))
_check_normalization(norm)
axis = axis or -1
axis = axis if axis is not None else -1
_check_fft_axis(x, axis)
axes = [axis]
axes = _normalize_axes(x, axes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册