提交 0d838c7c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!161 Edit loss_scale to fit GPU

Merge pull request !161 from VectorSL/master
......@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Loss scale cell for loss scale training."""
import mindspore.context as context
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
......@@ -34,6 +35,13 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()
@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
class DynamicLossScaleUpdateCell(Cell):
r"""
......@@ -197,9 +205,15 @@ class TrainOneStepWithLossScaleCell(Cell):
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap()
self.alloc_status = NPUAllocFloatStatus()
self.get_status = NPUGetFloatStatus()
self.clear_status = NPUClearFloatStatus()
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
else:
self.gpu_target = False
self.alloc_status = NPUAllocFloatStatus()
self.get_status = NPUGetFloatStatus()
self.clear_status = NPUClearFloatStatus()
self.reduce_sum = ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = LessEqual()
......@@ -224,10 +238,12 @@ class TrainOneStepWithLossScaleCell(Cell):
def construct(self, data, label, sens=None):
weights = self.weights
loss = self.network(data, label)
# init overflow buffer
init = self.alloc_status()
# clear overflow buffer
self.clear_status(init)
init = False
if not self.gpu_target:
# init overflow buffer
init = self.alloc_status()
# clear overflow buffer
self.clear_status(init)
if sens is None:
scaling_sens = self.loss_scale
else:
......@@ -238,9 +254,13 @@ class TrainOneStepWithLossScaleCell(Cell):
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
self.get_status(init)
# sum overflow buffer elements, 0:not overflow , >0:overflow
flag_sum = self.reduce_sum(init, (0,))
if not self.gpu_target:
self.get_status(init)
# sum overflow buffer elements, 0:not overflow , >0:overflow
flag_sum = self.reduce_sum(init, (0,))
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
......
......@@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
LogicalNot, LogicalOr, MatMul, Maximum,
Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus,
NPUGetFloatStatus, Pow, RealDiv,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum,
Sin, Sqrt, Rsqrt,
Square, Sub, TensorAdd, Sign, Round)
......@@ -154,6 +154,10 @@ __all__ = [
'NPUAllocFloatStatus',
'NPUGetFloatStatus',
'NPUClearFloatStatus',
'IsNan',
'IsFinite',
'IsInf',
'FloatStatus',
'Reciprocal',
'SmoothL1Loss',
'ReduceAll',
......
......@@ -1541,6 +1541,94 @@ class LogicalOr(_LogicBinaryOp):
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
class IsNan(PrimitiveWithInfer):
"""
Judging which elements are nan for each position
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
"""
@prim_attr_register
def __init__(self):
"""init IsNan"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
return mstype.bool_
class IsInf(PrimitiveWithInfer):
"""
Judging which elements are inf or -inf for each position
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
"""
@prim_attr_register
def __init__(self):
"""init IsInf"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
return mstype.bool_
class IsFinite(PrimitiveWithInfer):
"""
Judging which elements are finite for each position
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the same shape of input, and the dtype is bool.
"""
@prim_attr_register
def __init__(self):
"""init IsFinite"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
return mstype.bool_
class FloatStatus(PrimitiveWithInfer):
"""
Determine if the elements contains nan, inf or -inf. `0` for normal, `1` for overflow.
Inputs:
- **input_x** (Tensor) - The input tensor.
Outputs:
Tensor, has the shape of `(1,)`, and has the same dtype of input `mindspore.dtype.float32` or
`mindspore.dtype.float16`.
"""
@prim_attr_register
def __init__(self):
"""init FloatStatus"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
return [1]
def infer_dtype(self, x_dtype):
return x_dtype
class NPUAllocFloatStatus(PrimitiveWithInfer):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册