提交 09b43728 编写于 作者: K kingfo

support mix precision in pynative

上级 8f4bab4e
......@@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..ops.functional import cast
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor
......@@ -60,6 +61,7 @@ class Cell:
def __init__(self, auto_prefix=True, flags=None):
self._params = OrderedDict()
self._cells = OrderedDict()
self._params_list = OrderedDict()
self.training = False
self.requires_grad = False
self.pynative = False
......@@ -188,11 +190,22 @@ class Cell:
if '_params' in self.__dict__:
params = self.__dict__['_params']
if name in params:
if context.get_context("mode") == context.PYNATIVE_MODE:
return self.cast_param(params[name])
return params[name]
if '_cells' in self.__dict__:
cells = self.__dict__['_cells']
if name in cells:
return cells[name]
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__:
params_list = self.__dict__['_params_list']
if name in params_list:
para_list = params_list[name]
cast_list = list()
for para in para_list:
cast_list.append(self.cast_param(para))
para_list = ParameterTuple(cast_list)
return para_list
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
def __del__(self):
......@@ -225,10 +238,21 @@ class Cell:
cell.set_grad(True)
else:
_pynative_exec.set_grad_flag(False)
cast_inputs = list()
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float16))
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
for item in inputs:
cast_inputs.append(cast(item, mstype.float32))
if cast_inputs:
cast_inputs = tuple(cast_inputs)
else:
cast_inputs = inputs
if self.enable_hook:
output = self._hook_construct(*inputs)
output = self._hook_construct(*cast_inputs)
else:
output = self.construct(*inputs)
output = self.construct(*cast_inputs)
if isinstance(output, Parameter):
output = output.data
if self.requires_grad is True:
......@@ -241,6 +265,7 @@ class Cell:
def __setattr__(self, name, value):
cells = self.__dict__.get('_cells')
params = self.__dict__.get('_params')
params_list = self.__dict__.get('_params_list')
if isinstance(value, Parameter):
if params is None:
raise AttributeError("Can not assign params before Cell.__init__() call.")
......@@ -256,7 +281,12 @@ class Cell:
raise AttributeError("Can not assign params before Cell.__init__() call.")
for item in value:
self.insert_param_to_cell(item.name, item, check_name=False)
object.__setattr__(self, name, value)
if context.get_context("mode") == context.PYNATIVE_MODE:
if name in self.__dict__:
del self.__dict__[name]
params_list[name] = value
else:
object.__setattr__(self, name, value)
elif isinstance(value, Cell):
if cells is None:
raise AttributeError("Can not assign cells before Cell.__init__() call.")
......@@ -458,6 +488,19 @@ class Cell:
raise TypeError("The type of parameter should be 'Parameter' if not None.")
self._params[param_name] = param
def cast_param(self, param):
"""
Cast parameter according to auto mix precison level in pynative mode.
Args:
param (Parameter): The parameter to cast.
"""
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
return cast(param, mstype.float16)
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
return cast(param, mstype.float32)
return param
def insert_child_to_cell(self, child_name, child):
"""
Adds a child cell to the current cell.
......
......@@ -23,7 +23,7 @@ from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple, Parameter
context.set_context(device_target='CPU')
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class LstmNet(nn.Cell):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册