提交 89fce0e4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2040 fix paramter is metatensor bug in pynative mode

Merge pull request !2040 from flywind/fix_pynative_bug
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"""Parameter for cell.""" """Parameter for cell."""
import numbers import numbers
from copy import copy, deepcopy from copy import copy, deepcopy
from mindspore import context
from . import dtype as mstype from . import dtype as mstype
from .initializer import initializer, Initializer from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor from .tensor import Tensor, MetaTensor
...@@ -61,6 +62,8 @@ class Parameter: ...@@ -61,6 +62,8 @@ class Parameter:
self._is_init = False self._is_init = False
self._sliced = False self._sliced = False
self.clone_info = _CloneInfo() self.clone_info = _CloneInfo()
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
def __repr__(self): def __repr__(self):
format_str = 'Parameter (name={name})' format_str = 'Parameter (name={name})'
...@@ -142,6 +145,8 @@ class Parameter: ...@@ -142,6 +145,8 @@ class Parameter:
if isinstance(init, (str, Initializer, numbers.Number)): if isinstance(init, (str, Initializer, numbers.Number)):
x.init_mode = initializer(init, shape=shape, dtype=dtype) x.init_mode = initializer(init, shape=shape, dtype=dtype)
x.default_input = MetaTensor(dtype, shape) x.default_input = MetaTensor(dtype, shape)
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
else: else:
x.default_input = initializer(init, shape=shape, dtype=dtype) x.default_input = initializer(init, shape=shape, dtype=dtype)
......
...@@ -202,7 +202,6 @@ class Cell: ...@@ -202,7 +202,6 @@ class Cell:
if context.get_context("mode") == context.GRAPH_MODE: if context.get_context("mode") == context.GRAPH_MODE:
out = self.compile_and_run(*inputs) out = self.compile_and_run(*inputs)
return out return out
self.init_parameters_data()
orign_grad = [] orign_grad = []
if self.requires_grad is True: if self.requires_grad is True:
_pynative_exec.set_grad_flag(True) _pynative_exec.set_grad_flag(True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册