diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 4f45f73e88b03c25c7874183874ca360b0ac2140..1a79f56526204773cf6fa313da52d265e66758a6 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -210,7 +210,6 @@ class Parameter: def set_parameter_data(self, data): """Set `default_input` of current `Parameter`.""" - self.init_mode = None if isinstance(data, bool): raise ValueError('Parameter data can not be `bool`') if isinstance(data, Tensor): @@ -243,7 +242,8 @@ class Parameter: set_sliced (bool): True if should set parameter sliced after init the data of initializer. Default: False. """ - if self.init_mode is None: + if isinstance(self.default_input, Tensor): + # skip if data already initialized. return if layout is not None: if not isinstance(layout, list): diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 3c66e0a6dfb940a32dc1e7788be38607227cbc74..492e74922a270c0543a386c3cdb930d7869b5c4b 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -134,3 +134,19 @@ def test_check_str_by_regular(): _check_str_by_regular(str5) with pytest.raises(ValueError): _check_str_by_regular(str6) + +def test_parameter_lazy_init(): + # Call init_data() without set default_input. + para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1') + assert not isinstance(para.default_input, Tensor) + para.init_data() + assert isinstance(para.default_input, Tensor) + assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) + + # Call init_data() after default_input is set. + para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') + assert not isinstance(para.default_input, Tensor) + para.default_input = Tensor(np.zeros((1, 2, 3))) + assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) + para.init_data() # expect no effect. + assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))