From b8e6ca92274b8540cdb799fbb787b6bb7a44111c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 6 Feb 2023 15:19:25 +0800 Subject: [PATCH] fix div 0 error of ftt.rfftfreq (#49955) --- python/paddle/fft.py | 2 ++ .../fluid/tests/unittests/fft/test_fft.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/fft.py b/python/paddle/fft.py index b8939d08b58..96f743287ce 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1326,6 +1326,8 @@ def rfftfreq(n, d=1.0, dtype=None, name=None): # [0. , 0.66666669, 1.33333337]) """ + if d * n == 0: + raise ValueError("d or n should not be 0.") dtype = paddle.framework.get_default_dtype() val = 1.0 / (n * d) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index 8a57fa81b57..b4f92700104 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1860,6 +1860,23 @@ class TestRfftFreq(unittest.TestCase): ) +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'n', 'd', 'dtype', 'expect_exception'), + [ + ('test_with_0_0', 0, 0, 'float32', ValueError), + ('test_with_n_0', 20, 0, 'float32', ValueError), + ('test_with_0_d', 0, 20, 'float32', ValueError), + ], +) +class TestRfftFreqException(unittest.TestCase): + def test_rfftfreq2(self): + """Test fftfreq with d = 0""" + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.rfftfreq(self.n, self.d, self.dtype) + + @place(DEVICES) @parameterize( (TEST_CASE_NAME, 'x', 'axes', 'dtype'), -- GitLab