From 2bf82e7598bb319e6b959eb58579d39535c999e7 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Mon, 11 Oct 2021 11:24:40 +0800 Subject: [PATCH] fix fft axis (#36321) fix: `-1` is used when fft's axis is `0` --- python/paddle/tensor/fft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/fft.py b/python/paddle/tensor/fft.py index 829399d14ea..f7990e3f891 100644 --- a/python/paddle/tensor/fft.py +++ b/python/paddle/tensor/fft.py @@ -1340,7 +1340,7 @@ def fft_c2c(x, n, axis, norm, forward, name): x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) @@ -1370,7 +1370,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name): if is_interger(x): x = paddle.cast(x, paddle.get_default_dtype()) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) @@ -1409,7 +1409,7 @@ def fft_c2r(x, n, axis, norm, forward, name): elif is_floating_point(x): x = paddle.cast(x, _real_to_complex_dtype(x.dtype)) _check_normalization(norm) - axis = axis or -1 + axis = axis if axis is not None else -1 _check_fft_axis(x, axis) axes = [axis] axes = _normalize_axes(x, axes) -- GitLab