From 09b437285b2765cae11d77a7fdf3cca74be3fa96 Mon Sep 17 00:00:00 2001 From: kingfo Date: Tue, 21 Jul 2020 22:15:26 +0800 Subject: [PATCH] support mix precision in pynative --- mindspore/nn/cell.py | 49 ++++++++++++++++++++++++++++++-- tests/st/ops/cpu/test_lstm_op.py | 2 +- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 2209a3f96..d7e18c67f 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple from .._c_expression import init_backend from ..ops.primitive import Primitive from ..ops.operations import HookBackward +from ..ops.functional import cast from ..parallel._tensor import _load_tensor_by_layout from ..common.tensor import Tensor @@ -60,6 +61,7 @@ class Cell: def __init__(self, auto_prefix=True, flags=None): self._params = OrderedDict() self._cells = OrderedDict() + self._params_list = OrderedDict() self.training = False self.requires_grad = False self.pynative = False @@ -188,11 +190,22 @@ class Cell: if '_params' in self.__dict__: params = self.__dict__['_params'] if name in params: + if context.get_context("mode") == context.PYNATIVE_MODE: + return self.cast_param(params[name]) return params[name] if '_cells' in self.__dict__: cells = self.__dict__['_cells'] if name in cells: return cells[name] + if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__: + params_list = self.__dict__['_params_list'] + if name in params_list: + para_list = params_list[name] + cast_list = list() + for para in para_list: + cast_list.append(self.cast_param(para)) + para_list = ParameterTuple(cast_list) + return para_list raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) def __del__(self): @@ -225,10 +238,21 @@ class Cell: cell.set_grad(True) else: _pynative_exec.set_grad_flag(False) + cast_inputs = list() + if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'): + for item in inputs: + cast_inputs.append(cast(item, mstype.float16)) + if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'): + for item in inputs: + cast_inputs.append(cast(item, mstype.float32)) + if cast_inputs: + cast_inputs = tuple(cast_inputs) + else: + cast_inputs = inputs if self.enable_hook: - output = self._hook_construct(*inputs) + output = self._hook_construct(*cast_inputs) else: - output = self.construct(*inputs) + output = self.construct(*cast_inputs) if isinstance(output, Parameter): output = output.data if self.requires_grad is True: @@ -241,6 +265,7 @@ class Cell: def __setattr__(self, name, value): cells = self.__dict__.get('_cells') params = self.__dict__.get('_params') + params_list = self.__dict__.get('_params_list') if isinstance(value, Parameter): if params is None: raise AttributeError("Can not assign params before Cell.__init__() call.") @@ -256,7 +281,12 @@ class Cell: raise AttributeError("Can not assign params before Cell.__init__() call.") for item in value: self.insert_param_to_cell(item.name, item, check_name=False) - object.__setattr__(self, name, value) + if context.get_context("mode") == context.PYNATIVE_MODE: + if name in self.__dict__: + del self.__dict__[name] + params_list[name] = value + else: + object.__setattr__(self, name, value) elif isinstance(value, Cell): if cells is None: raise AttributeError("Can not assign cells before Cell.__init__() call.") @@ -458,6 +488,19 @@ class Cell: raise TypeError("The type of parameter should be 'Parameter' if not None.") self._params[param_name] = param + def cast_param(self, param): + """ + Cast parameter according to auto mix precison level in pynative mode. + + Args: + param (Parameter): The parameter to cast. + """ + if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'): + return cast(param, mstype.float16) + if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'): + return cast(param, mstype.float32) + return param + def insert_child_to_cell(self, child_name, child): """ Adds a child cell to the current cell. diff --git a/tests/st/ops/cpu/test_lstm_op.py b/tests/st/ops/cpu/test_lstm_op.py index 7992bfbf0..c8174a5f9 100644 --- a/tests/st/ops/cpu/test_lstm_op.py +++ b/tests/st/ops/cpu/test_lstm_op.py @@ -23,7 +23,7 @@ from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import ParameterTuple, Parameter -context.set_context(device_target='CPU') +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') class LstmNet(nn.Cell): -- GitLab