diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py
index 780b5fe3671ef5fd6c6e395657e3364acd966137..fc0498f342316b1efab3adca006a6a627cf9221b 100644
--- a/mindspore/_extends/builtin_operations.py
+++ b/mindspore/_extends/builtin_operations.py
@@ -15,6 +15,7 @@
 """builtin_operations"""
 import numpy as np
 from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
 
@@ -173,11 +174,11 @@ def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
 
+
+hyper_map = C.HyperMap()
+
 def mixed_precision_cast(dst_type, x):
     """Implement `mixed_precision_cast`."""
-    if isinstance(x, tuple):
-        res = list()
-        for item in x:
-            res.append(F.cast(item, dst_type))
-        return tuple(res)
-    return F.cast(x, dst_type)
+    def cast_inner(data):
+        return F.cast(data, dst_type)
+    return hyper_map(cast_inner, x)
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 1a79f56526204773cf6fa313da52d265e66758a6..28c7c9fa4d7ebede1d48e8f07d357c883bc4b5d5 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -61,6 +61,7 @@
         self._is_init = False
         self._sliced = False
         self.is_param_ps = False
+        self._cast_type = None
         self.init_in_server = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
             self.init_data()
@@ -103,6 +104,16 @@
             raise ValueError("The type of the name should be `str` or `None`.")
         self._value.name = name_
 
+    @property
+    def cast_type(self):
+        return self._cast_type
+
+    @cast_type.setter
+    def cast_type(self, dst_type):
+        if dst_type not in (mstype.float16, mstype.float32, None):
+            raise ValueError("The cast type should be mstype.float16, mstype.float32 or None.")
+        self._cast_type = dst_type
+
     @property
     def sliced(self):
         """Get slice status of the parameter."""
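Two notes on the hunks above. In `builtin_operations.py`, replacing the hand-rolled tuple loop with `C.HyperMap()` is what lets `mixed_precision_cast` handle arbitrarily nested inputs: the removed code unrolled exactly one level of tuple. A rough plain-Python analogue of what `hyper_map(cast_inner, x)` provides here (an illustration only, not MindSpore's `HyperMap` implementation):

```python
# Illustrative stand-in for hyper_map(cast_inner, x): apply fn to every
# leaf of a nested tuple/list structure while preserving the structure.
def recursive_cast(fn, x):
    if isinstance(x, (tuple, list)):
        return type(x)(recursive_cast(fn, item) for item in x)
    return fn(x)

# recursive_cast(cast_inner, ((a, b), c)) casts a, b and c, where the
# removed loop handled only a flat tuple such as (a, b, c).
```

In `parameter.py`, the new `cast_type` slot does not cast anything by itself; it records the target dtype that `_run_op` in `primitive.py` consumes later.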
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 45ffc9fd8ca3f124667e015049c801e0d634b228..10be325b80c20592fc70be0fe2d36808fa564aba 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -278,7 +278,7 @@ class SparseTensor:
     Returns:
         SparseTensor, composed of `indices`, `values`, `dense_shape`.
 
-    Examples: 
+    Examples:
         >>> class Net(nn.Cell):
         >>>     def __init__(self, dense_shape):
         >>>         super(Net, self).__init__()
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 1dbaae2fa6eb25b299155057c2ad7f6dcc903231..b474732c2e15df78285d7e98a2882691ca81a359 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -286,6 +286,8 @@ class Cell:
             if context.get_context("mode") == context.PYNATIVE_MODE:
                 if name in self.__dict__:
                     del self.__dict__[name]
+                if name in params:
+                    del params[name]
             params_list[name] = value
         else:
             object.__setattr__(self, name, value)
@@ -499,9 +501,11 @@
         """
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                return cast(param, mstype.float16)
-            if self._mindspore_flags.get('fp32'):
-                return cast(param, mstype.float32)
+                param.cast_type = mstype.float16
+            elif self._mindspore_flags.get('fp32'):
+                param.cast_type = mstype.float32
+            else:
+                param.cast_type = None
         return param
 
     def insert_child_to_cell(self, child_name, child):
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index c6586996bae7827bb15abd23a0cf38d0aeeb9e49..a82cf13450354e5aac3c692212d29d2232f46a9e 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
 tensor_operator_registry.register('shape', shape)
 #support GE backend for no compare operators
 tensor_operator_registry.register('vm_compare', BP.vm_compare)
+tensor_operator_registry.register('cast', cast)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 8ca6521e81162b60774b9c1e1bcfd2269e908e4e..c0d6cc4d0cacd076843746a41644a2ce57dbd6af 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -618,6 +618,7 @@
         self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
         self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self._update_parameter = True
 
 
 class BNTrainingReduce(PrimitiveWithInfer):
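Taken together, the `cell.py` and `nn_ops.py` hunks defer mixed-precision casting: `cast_param` now only tags the `Parameter` with a target dtype, and the actual cast happens per operator at execution time. `FusedBatchNorm` sets `_update_parameter = True` because it writes its moving statistics back into its parameters, and such an op must receive the original `Parameter` objects: `cast` produces a new tensor, so any write-back through a cast copy would be lost. A minimal plain-Python sketch of that failure mode (`Param` and `cast_copy` are hypothetical stand-ins, not MindSpore API):

```python
class Param:
    """Hypothetical stand-in for mindspore's Parameter."""
    def __init__(self, value):
        self.value = value

def cast_copy(p):
    # Stands in for F.cast(param, dst_type): it returns a new object,
    # so in-place updates made through the copy never reach the original.
    return Param(p.value)

moving_mean = Param(0.0)
copy = cast_copy(moving_mean)
copy.value = 0.9                  # the op "updates" the cast copy...
assert moving_mean.value == 0.0   # ...and the real parameter is unchanged
```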
"""Single op execution function supported by ge in PyNative mode.""" - output = real_run_op(obj, op_name, args) + cast = tensor_operator_registry.get("cast") + if op_name == "Cast" or obj.update_parameter: + cast_args = args + else: + cast_args = list() + for arg in args: + if isinstance(arg, Parameter): + if arg.cast_type: + cast_args.append(cast(arg, arg.cast_type)) + else: + cast_args.append(arg) + else: + cast_args.append(arg) + output = real_run_op(obj, op_name, tuple(cast_args)) if not output: raise RuntimeError("Pynative run op %s failed!" % op_name) if len(output) == 1: