未验证 提交 35fa30d0 编写于 作者: W wanghuancoder 提交者: GitHub

[Bug fix] Fix kaiming initializer div zero (#49656)

* fix kaiming initializer div zero
上级 6f20a383
......@@ -783,6 +783,16 @@ class MSRAInitializer(Initializer):
# If fan_in is passed, use it
fan_in = f_in if self._fan_in is None else self._fan_in
if fan_in == 0:
if self._fan_in is None:
raise ValueError(
"The in_features of the Tensor contain zero, can not initialize the Tensor."
)
else:
raise ValueError(
"fan_in should not be zero, can not initialize the Tensor."
)
if self._seed == 0:
self._seed = block.program.random_seed
......
......@@ -1176,6 +1176,22 @@ class TestDiracInitializer3(TestDiracInitializer1):
paddle.nn.Conv2D(5, 9, (3, 3), weight_attr=self.weight_attr)
class TestKaimingUniform(unittest.TestCase):
def func_kaiminguniform_initializer_fan_in_zero(self):
paddle.enable_static()
x = paddle.static.data(name='x', shape=[1, 0, 0], dtype='float32')
kaiming = paddle.nn.initializer.KaimingUniform(0)
param_attr = paddle.ParamAttr(initializer=kaiming)
paddle.static.nn.prelu(x, 'all', param_attr=param_attr)
def test_type_error(self):
self.assertRaises(
ValueError, self.func_kaiminguniform_initializer_fan_in_zero
)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册