From a6db29b11b15f1d361e110df2bca44d006de34d5 Mon Sep 17 00:00:00 2001 From: He Wei Date: Mon, 27 Jul 2020 14:58:27 +0800 Subject: [PATCH] Fix a bug for Parameter 1. Parameter's init_data() should have no effect if default_input already set; 2. This bug is introduced by 'decouple ParamValue from python'; 3. An unit test case is added to ensure the right init_data() behavior. --- mindspore/common/parameter.py | 4 ++-- tests/ut/python/nn/test_parameter.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 4f45f73e8..1a79f5652 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -210,7 +210,6 @@ class Parameter: def set_parameter_data(self, data): """Set `default_input` of current `Parameter`.""" - self.init_mode = None if isinstance(data, bool): raise ValueError('Parameter data can not be `bool`') if isinstance(data, Tensor): @@ -243,7 +242,8 @@ class Parameter: set_sliced (bool): True if should set parameter sliced after init the data of initializer. Default: False. """ - if self.init_mode is None: + if isinstance(self.default_input, Tensor): + # skip if data already initialized. return if layout is not None: if not isinstance(layout, list): diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 3c66e0a6d..492e74922 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -134,3 +134,19 @@ def test_check_str_by_regular(): _check_str_by_regular(str5) with pytest.raises(ValueError): _check_str_by_regular(str6) + +def test_parameter_lazy_init(): + # Call init_data() without set default_input. + para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1') + assert not isinstance(para.default_input, Tensor) + para.init_data() + assert isinstance(para.default_input, Tensor) + assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) + + # Call init_data() after default_input is set. + para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') + assert not isinstance(para.default_input, Tensor) + para.default_input = Tensor(np.zeros((1, 2, 3))) + assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) + para.init_data() # expect no effect. + assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) -- GitLab