提交 c31c1c80 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3530 Fix a bug for Parameter

Merge pull request !3530 from hewei/fix_parameter_bug_r0.6
......@@ -210,7 +210,6 @@ class Parameter:
def set_parameter_data(self, data):
"""Set `default_input` of current `Parameter`."""
self.init_mode = None
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Tensor):
......@@ -243,7 +242,8 @@ class Parameter:
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
"""
if self.init_mode is None:
if isinstance(self.default_input, Tensor):
# skip if data already initialized.
return
if layout is not None:
if not isinstance(layout, list):
......
......@@ -134,3 +134,19 @@ def test_check_str_by_regular():
_check_str_by_regular(str5)
with pytest.raises(ValueError):
_check_str_by_regular(str6)
def test_parameter_lazy_init():
# Call init_data() without set default_input.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
assert not isinstance(para.default_input, Tensor)
para.init_data()
assert isinstance(para.default_input, Tensor)
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
# Call init_data() after default_input is set.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
assert not isinstance(para.default_input, Tensor)
para.default_input = Tensor(np.zeros((1, 2, 3)))
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
para.init_data() # expect no effect.
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册