未验证 提交 b8e6ca92 编写于 作者: 张春乔 提交者: GitHub

fix div 0 error of ftt.rfftfreq (#49955)

上级 e12c9221
......@@ -1326,6 +1326,8 @@ def rfftfreq(n, d=1.0, dtype=None, name=None):
# [0. , 0.66666669, 1.33333337])
"""
if d * n == 0:
raise ValueError("d or n should not be 0.")
dtype = paddle.framework.get_default_dtype()
val = 1.0 / (n * d)
......
......@@ -1860,6 +1860,23 @@ class TestRfftFreq(unittest.TestCase):
)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'n', 'd', 'dtype', 'expect_exception'),
[
('test_with_0_0', 0, 0, 'float32', ValueError),
('test_with_n_0', 20, 0, 'float32', ValueError),
('test_with_0_d', 0, 20, 'float32', ValueError),
],
)
class TestRfftFreqException(unittest.TestCase):
def test_rfftfreq2(self):
"""Test fftfreq with d = 0"""
with paddle.fluid.dygraph.guard(self.place):
with self.assertRaises(self.expect_exception):
paddle.fft.rfftfreq(self.n, self.d, self.dtype)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'dtype'),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册