提交 104c82b1 编写于 作者: H huangjun12 提交者: chajchaj

fix lrn bug when shape=0

上级 40bd7a7a
...@@ -323,6 +323,12 @@ class TestLocalResponseNormFAPIError(unittest.TestCase): ...@@ -323,6 +323,12 @@ class TestLocalResponseNormFAPIError(unittest.TestCase):
self.assertRaises(ValueError, test_dim) self.assertRaises(ValueError, test_dim)
def test_shape():
x = fluid.data(name='x', shape=[0, 0, 2, 3], dtype="float32")
paddle.nn.functional.local_response_norm(x, size=5)
self.assertRaises(ValueError, test_shape)
class TestLocalResponseNormCAPI(unittest.TestCase): class TestLocalResponseNormCAPI(unittest.TestCase):
def setUp(self): def setUp(self):
......
...@@ -487,6 +487,12 @@ def local_response_norm(x, ...@@ -487,6 +487,12 @@ def local_response_norm(x,
'Expected 3D or higher dimensionality input, but got {} dimensions'. 'Expected 3D or higher dimensionality input, but got {} dimensions'.
format(dim)) format(dim))
for i, size in enumerate(sizes):
if not size > 0:
raise ValueError("Expected every dim's size to be larger than 0, "
"but the size of the {}-th dim is {}".format(i,
size))
channel_last = True if data_format[-1] == "C" else False channel_last = True if data_format[-1] == "C" else False
from functools import reduce from functools import reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册