Commit ef2f03c1 authored by chengang, committed by 高东海

Revert 'Pull Request !133 : Edit loss_scale to fit GPU'

Parent cdbef85e
@@ -25,7 +25,6 @@ from ...ops import operations as P
 from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \
     ControlDepend
 from ...common import dtype as mstype
-import mindspore.context as context

 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
@@ -35,12 +34,6 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))

-_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
-grad_overflow = P.FloatStatus()
-
-@_grad_overflow.register("Tensor")
-def _tensor_grad_overflow(grad):
-    return grad_overflow(grad)

 class DynamicLossScaleUpdateCell(Cell):
     r"""
@@ -204,15 +197,9 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-        else:
-            self.gpu_target = False
-            self.alloc_status = NPUAllocFloatStatus()
-            self.get_status = NPUGetFloatStatus()
-            self.clear_status = NPUClearFloatStatus()
+        self.alloc_status = NPUAllocFloatStatus()
+        self.get_status = NPUGetFloatStatus()
+        self.clear_status = NPUClearFloatStatus()
         self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = LessEqual()
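The branch deleted here chose overflow-detection ops by device target when the cell was built. A minimal standalone sketch of that dispatch, mirroring the removed lines rather than a definitive implementation (`P.FloatStatus` only exists pre-revert):

```python
import mindspore.context as context
from mindspore.ops import operations as P

if context.get_context("device_target") == "GPU":
    # GPU has no dedicated float-status register; probe each gradient tensor.
    float_status, addn = P.FloatStatus(), P.AddN()
else:
    # Ascend exposes a hardware float-status buffer via the NPU* ops.
    alloc_status = P.NPUAllocFloatStatus()
    get_status = P.NPUGetFloatStatus()
    clear_status = P.NPUClearFloatStatus()
```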
@@ -237,11 +224,10 @@ class TrainOneStepWithLossScaleCell(Cell):
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        if not self.gpu_target:
-            # init overflow buffer
-            init = self.alloc_status()
-            # clear overflow buffer
-            self.clear_status(init)
+        # init overflow buffer
+        init = self.alloc_status()
+        # clear overflow buffer
+        self.clear_status(init)
         if sens is None:
             scaling_sens = self.loss_scale
         else:
@@ -251,14 +237,10 @@ class TrainOneStepWithLossScaleCell(Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        if not self.gpu_target:
-            # get the overflow buffer
-            self.get_status(init)
-            # sum overflow buffer elements, 0:not overflow , >0:overflow
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
+        # get the overflow buffer
+        self.get_status(init)
+        # sum overflow buffer elements, 0:not overflow , >0:overflow
+        flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
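Both detection paths feed the same decision: skip the parameter update whenever the unscaled gradients are not finite. A plain-NumPy sketch of that overflow-gated step (names here are illustrative, not from the source):

```python
import numpy as np

def scaled_train_step(params, raw_grads, loss_scale, lr=0.01):
    """One loss-scaled step: unscale the grads, skip the update on overflow."""
    grads = [g / loss_scale for g in raw_grads]              # undo the loss scaling
    overflow = any(not np.isfinite(g).all() for g in grads)
    if not overflow:                                         # apply only clean gradients
        params = [p - lr * g for p, g in zip(params, grads)]
    return params, overflow
```

In the cell above the same check is expressed with graph ops: `flag_sum` counts nonzero status entries, and `self.less_equal(self.base, flag_sum)` (i.e. `1 <= flag_sum`) marks the step as overflowed.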
......
@@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
                        LogicalNot, LogicalOr, MatMul, Maximum,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus,
-                       NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
+                       NPUGetFloatStatus, Pow, RealDiv,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt,
                        Square, Sub, TensorAdd, Sign, Round)
@@ -151,10 +151,6 @@ __all__ = [
     'Neg',
     'Slice',
     'DType',
-    'IsNan',
-    'IsInf',
-    'IsFinite',
-    'FloatStatus',
     'NPUAllocFloatStatus',
     'NPUGetFloatStatus',
     'NPUClearFloatStatus',
......
@@ -1541,89 +1541,6 @@ class LogicalOr(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())

-class IsNan(PrimitiveWithInfer):
-    """
-    Judging which elements are nan for each position
-
-    Inputs:
-        - **input_x** (Tensor) - The input tensor.
-
-    Outputs:
-        Tensor, has the same shape of input.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """init IsNan"""
-        self.init_prim_io_names(inputs=['x'], outputs=['output'])
-
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        return mstype.bool_
-
-
-class IsInf(PrimitiveWithInfer):
-    """
-    Judging which elements are inf or -inf for each position
-
-    Inputs:
-        - **input_x** (Tensor) - The input tensor.
-
-    Outputs:
-        Tensor, has the same shape of input.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """init IsInf"""
-        self.init_prim_io_names(inputs=['x'], outputs=['output'])
-
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        return mstype.bool_
-
-
-class IsFinite(PrimitiveWithInfer):
-    """
-    Judging which elements are finite for each position
-
-    Inputs:
-        - **input_x** (Tensor) - The input tensor.
-
-    Outputs:
-        Tensor, has the same shape of input.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """init IsFinite"""
-        self.init_prim_io_names(inputs=['x'], outputs=['output'])
-
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        return mstype.bool_
-
-
-class FloatStatus(PrimitiveWithInfer):
-    """
-    Determine if the elements contains nan, inf or -inf
-
-    Inputs:
-        - **input_x** (Tensor) - The input tensor.
-
-    Outputs:
-        Tensor, has the shape of `(1,)`.
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """init FloatStatus"""
-        self.init_prim_io_names(inputs=['x'], outputs=['output'])
-
-    def infer_shape(self, x_shape):
-        return [1]
-
-    def infer_dtype(self, x_dtype):
-        return x_dtype
-
-
 class NPUAllocFloatStatus(PrimitiveWithInfer):
     """
......