diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 3be5472402b6a8f93dee93bf8901537e9d6b9804..a943cbcf63a92f96cfcd16bf087e4cef1a11f3cf 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -16,6 +16,7 @@ """Parameter for cell.""" import numbers from copy import copy, deepcopy +from mindspore import context from . import dtype as mstype from .initializer import initializer, Initializer from .tensor import Tensor, MetaTensor @@ -61,6 +62,8 @@ class Parameter: self._is_init = False self._sliced = False self.clone_info = _CloneInfo() + if context.get_context("mode") == context.PYNATIVE_MODE: + self.init_data() def __repr__(self): format_str = 'Parameter (name={name})' @@ -142,6 +145,8 @@ class Parameter: if isinstance(init, (str, Initializer, numbers.Number)): x.init_mode = initializer(init, shape=shape, dtype=dtype) x.default_input = MetaTensor(dtype, shape) + if context.get_context("mode") == context.PYNATIVE_MODE: + x.init_data() else: x.default_input = initializer(init, shape=shape, dtype=dtype) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 1b3ebb828c175ad84ed3e3302db285ad8fdc206a..4ea3eb53646cd9ca77ee175d1284d81adc8913fe 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -202,7 +202,6 @@ class Cell: if context.get_context("mode") == context.GRAPH_MODE: out = self.compile_and_run(*inputs) return out - self.init_parameters_data() orign_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True)