提交 89fce0e4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2040 fix paramter is metatensor bug in pynative mode

Merge pull request !2040 from flywind/fix_pynative_bug
......@@ -16,6 +16,7 @@
"""Parameter for cell."""
import numbers
from copy import copy, deepcopy
from mindspore import context
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
......@@ -61,6 +62,8 @@ class Parameter:
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
def __repr__(self):
format_str = 'Parameter (name={name})'
......@@ -142,6 +145,8 @@ class Parameter:
if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape)
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)
......
......@@ -202,7 +202,6 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs)
return out
self.init_parameters_data()
orign_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册