提交 fcdad59c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3594 fix batchnorm issue under mix precision in pynative mode

Merge pull request !3594 from wangqiuliang/fix-batchnorm-under-mix-precision-in-pynative
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""builtin_operations""" """builtin_operations"""
import numpy as np import numpy as np
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
...@@ -173,11 +174,11 @@ def stop_gradient(x): ...@@ -173,11 +174,11 @@ def stop_gradient(x):
"""Implement `stop_gradient`.""" """Implement `stop_gradient`."""
return x return x
hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x): def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`.""" """Implement `mixed_precision_cast`."""
if isinstance(x, tuple): def cast_inner(data):
res = list() return F.cast(data, dst_type)
for item in x: return hyper_map(cast_inner, x)
res.append(F.cast(item, dst_type))
return tuple(res)
return F.cast(x, dst_type)
...@@ -61,6 +61,7 @@ class Parameter: ...@@ -61,6 +61,7 @@ class Parameter:
self._is_init = False self._is_init = False
self._sliced = False self._sliced = False
self.is_param_ps = False self.is_param_ps = False
self._cast_type = None
self.init_in_server = False self.init_in_server = False
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data() self.init_data()
...@@ -103,6 +104,16 @@ class Parameter: ...@@ -103,6 +104,16 @@ class Parameter:
raise ValueError("The type of the name should be `str` or `None`.") raise ValueError("The type of the name should be `str` or `None`.")
self._value.name = name_ self._value.name = name_
@property
def cast_type(self):
return self._cast_type
@cast_type.setter
def cast_type(self, dst_type):
if dst_type not in (mstype.float16, mstype.float32, None):
raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
self._cast_type = dst_type
@property @property
def sliced(self): def sliced(self):
"""Get slice status of the parameter.""" """Get slice status of the parameter."""
......
...@@ -278,7 +278,7 @@ class SparseTensor: ...@@ -278,7 +278,7 @@ class SparseTensor:
Returns: Returns:
SparseTensor, composed of `indices`, `values`, `dense_shape`. SparseTensor, composed of `indices`, `values`, `dense_shape`.
Examples: Examples:
>>> class Net(nn.Cell): >>> class Net(nn.Cell):
>>> def __init__(self, dense_shape): >>> def __init__(self, dense_shape):
>>> super(Net, self).__init__() >>> super(Net, self).__init__()
......
...@@ -286,6 +286,8 @@ class Cell: ...@@ -286,6 +286,8 @@ class Cell:
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
if name in self.__dict__: if name in self.__dict__:
del self.__dict__[name] del self.__dict__[name]
if name in params:
del params[name]
params_list[name] = value params_list[name] = value
else: else:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
...@@ -499,9 +501,11 @@ class Cell: ...@@ -499,9 +501,11 @@ class Cell:
""" """
if hasattr(self, "_mindspore_flags"): if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'): if self._mindspore_flags.get('fp16'):
return cast(param, mstype.float16) param.cast_type = mstype.float16
if self._mindspore_flags.get('fp32'): elif self._mindspore_flags.get('fp32'):
return cast(param, mstype.float32) param.cast_type = mstype.float32
else:
param.cast_type = None
return param return param
def insert_child_to_cell(self, child_name, child): def insert_child_to_cell(self, child_name, child):
......
...@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge) ...@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape) tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators #support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare) tensor_operator_registry.register('vm_compare', BP.vm_compare)
tensor_operator_registry.register('cast', cast)
...@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive): ...@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self._update_parameter = True
class BNTrainingReduce(PrimitiveWithInfer): class BNTrainingReduce(PrimitiveWithInfer):
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
import inspect import inspect
import copy import copy
from mindspore.common.api import _wrap_func from mindspore.common.api import _wrap_func
from mindspore.common import Parameter
from mindspore.common._register_for_tensor import tensor_operator_registry
from .._c_expression import Primitive_, real_run_op, prim_type from .._c_expression import Primitive_, real_run_op, prim_type
from .._c_expression import signature_rw as sig_rw from .._c_expression import signature_rw as sig_rw
from .._c_expression import signature_kind as sig_kind from .._c_expression import signature_kind as sig_kind
...@@ -49,6 +51,7 @@ class Primitive(Primitive_): ...@@ -49,6 +51,7 @@ class Primitive(Primitive_):
self.name = name self.name = name
self.attrs = {} self.attrs = {}
self.init_attrs = {"name": name} self.init_attrs = {"name": name}
self._update_parameter = False
Primitive_.__init__(self, name, self) Primitive_.__init__(self, name, self)
if hasattr(self.__class__, '__mindspore_signature__'): if hasattr(self.__class__, '__mindspore_signature__'):
sig = self._fill_signature(self.__class__.__mindspore_signature__) sig = self._fill_signature(self.__class__.__mindspore_signature__)
...@@ -189,6 +192,11 @@ class Primitive(Primitive_): ...@@ -189,6 +192,11 @@ class Primitive(Primitive_):
# for checking output number with kernel implementation # for checking output number with kernel implementation
self.add_prim_attr("output_names", outputs) self.add_prim_attr("output_names", outputs)
@property
def update_parameter(self):
""" Whether the primitive will update the value of parameter."""
return self._update_parameter
class PrimitiveWithInfer(Primitive): class PrimitiveWithInfer(Primitive):
""" """
...@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None): ...@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func @_wrap_func
def _run_op(obj, op_name, args): def _run_op(obj, op_name, args):
"""Single op execution function supported by ge in PyNative mode.""" """Single op execution function supported by ge in PyNative mode."""
output = real_run_op(obj, op_name, args) cast = tensor_operator_registry.get("cast")
if op_name == "Cast" or obj.update_parameter:
cast_args = args
else:
cast_args = list()
for arg in args:
if isinstance(arg, Parameter):
if arg.cast_type:
cast_args.append(cast(arg, arg.cast_type))
else:
cast_args.append(arg)
else:
cast_args.append(arg)
output = real_run_op(obj, op_name, tuple(cast_args))
if not output: if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name) raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1: if len(output) == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册