提交 fab9fac1 编写于 作者: K kingfo

fix batchnorm under mix precision in pynative mode

上级 b75943f2
......@@ -15,6 +15,7 @@
"""builtin_operations"""
import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
......@@ -173,11 +174,11 @@ def stop_gradient(x):
"""Implement `stop_gradient`."""
return x
hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
if isinstance(x, tuple):
res = list()
for item in x:
res.append(F.cast(item, dst_type))
return tuple(res)
return F.cast(x, dst_type)
def cast_inner(data):
return F.cast(data, dst_type)
return hyper_map(cast_inner, x)
......@@ -61,6 +61,7 @@ class Parameter:
self._is_init = False
self._sliced = False
self.is_param_ps = False
self._cast_type = None
self.init_in_server = False
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
......@@ -103,6 +104,16 @@ class Parameter:
raise ValueError("The type of the name should be `str` or `None`.")
self._value.name = name_
@property
def cast_type(self):
return self._cast_type
@cast_type.setter
def cast_type(self, dst_type):
if dst_type not in (mstype.float16, mstype.float32, None):
raise ValueError("The type of the name should be type of [float32, float16] or `None`.")
self._cast_type = dst_type
@property
def sliced(self):
"""Get slice status of the parameter."""
......
......@@ -286,6 +286,8 @@ class Cell:
if context.get_context("mode") == context.PYNATIVE_MODE:
if name in self.__dict__:
del self.__dict__[name]
if name in params:
del params[name]
params_list[name] = value
else:
object.__setattr__(self, name, value)
......@@ -499,9 +501,11 @@ class Cell:
"""
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp16'):
return cast(param, mstype.float16)
if self._mindspore_flags.get('fp32'):
return cast(param, mstype.float32)
param.cast_type = mstype.float16
elif self._mindspore_flags.get('fp32'):
param.cast_type = mstype.float32
else:
param.cast_type = None
return param
def insert_child_to_cell(self, child_name, child):
......
......@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
tensor_operator_registry.register('cast', cast)
......@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self._update_parameter = True
class BNTrainingReduce(PrimitiveWithInfer):
......
......@@ -18,6 +18,8 @@
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore.common import Parameter
from mindspore.common._register_for_tensor import tensor_operator_registry
from .._c_expression import Primitive_, real_run_op, prim_type
from .._c_expression import signature_rw as sig_rw
from .._c_expression import signature_kind as sig_kind
......@@ -49,6 +51,7 @@ class Primitive(Primitive_):
self.name = name
self.attrs = {}
self.init_attrs = {"name": name}
self._update_parameter = False
Primitive_.__init__(self, name, self)
if hasattr(self.__class__, '__mindspore_signature__'):
sig = self._fill_signature(self.__class__.__mindspore_signature__)
......@@ -189,6 +192,11 @@ class Primitive(Primitive_):
# for checking output number with kernel implementation
self.add_prim_attr("output_names", outputs)
@property
def update_parameter(self):
""" Whether the primitive will update the value of parameter."""
return self._update_parameter
class PrimitiveWithInfer(Primitive):
"""
......@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func
def _run_op(obj, op_name, args):
"""Single op execution function supported by ge in PyNative mode."""
output = real_run_op(obj, op_name, args)
cast = tensor_operator_registry.get("cast")
if op_name == "Cast" or obj.update_parameter:
cast_args = args
else:
cast_args = list()
for arg in args:
if isinstance(arg, Parameter):
if arg.cast_type:
cast_args.append(cast(arg, arg.cast_type))
else:
cast_args.append(arg)
else:
cast_args.append(arg)
output = real_run_op(obj, op_name, tuple(cast_args))
if not output:
raise RuntimeError("Pynative run op %s failed!" % op_name)
if len(output) == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册