Commit 778fdf6e authored by mindspore-ci-bot, committed by Gitee

!2672 auto parallel for sparse tensor gradient

Merge pull request !2672 from lirongzhen1/r0.5
@@ -16,18 +16,22 @@
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
 import mindspore.common.dtype as mstype
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
 _all_reduce = AllReduce()
+_all_gather = None
 
 
-def _init_optimizer_allreduce():
+def _init_optimizer_communication():
     global _all_reduce
+    global _all_gather
     _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
     _all_reduce.add_prim_attr('fusion', 1)
+    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
 
 
 @reduce_opt.register("Function", "Number", "Bool", "Tensor")
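A note on the hunk above: in this module a sparse gradient travels as an `(indices, values, dense_shape)` tuple, and `AllGather` concatenates each rank's tensor along axis 0, which is why it (rather than `AllReduce`) is the collective introduced here for sparse grads. Below is a minimal NumPy sketch of that semantics, not the MindSpore primitive itself; the rank count and per-rank data are invented for illustration.

```python
# Illustration only: simulate AllGather's axis-0 concatenation across ranks.
import numpy as np

def all_gather(per_rank_tensors):
    """Stand-in for the AllGather collective: concatenate along axis 0."""
    return np.concatenate(per_rank_tensors, axis=0)

# Two ranks, each holding a sparse gradient slice for a (6, 3) embedding table.
rank0 = (np.array([0, 2]), np.ones((2, 3), np.float32))        # (indices, values)
rank1 = (np.array([1, 2]), np.full((2, 3), 2.0, np.float32))

indices = all_gather([rank0[0], rank1[0]])   # [0 2 1 2]
values = all_gather([rank0[1], rank1[1]])    # shape (4, 3)
# Duplicate indices (2 appears twice) are left as-is; a sparse optimizer
# apply accumulates rows with the same index, so no post-gather reduction
# is needed here.
print(indices, values.shape)
```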
@@ -72,8 +76,8 @@ def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
         degree = F.scalar_cast(degree, F.dtype(grad[1]))
         dout = _all_gather(grad[1])
         cast_op = P.Cast()
-        dout = mul(dout, cast_op(F.scalar_to_array(1.0/degree), F.dtype(dout)))
-        grad = (indices, dout, dout[2])
+        dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
+        grad = (indices, dout, grad[2])
     return grad
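The first change in this hunk is cosmetic spacing; the second is the actual fix: the dense shape of the rebuilt sparse gradient must come from the incoming `grad` tuple, whereas `dout[2]` indexed into the gathered values tensor. A NumPy sketch of what the corrected mean branch computes (illustration only; `degree` is the device count, and the data is invented):

```python
# Illustration only: mean-reduce a sparse gradient across `degree` devices.
import numpy as np

def all_gather(per_rank):  # stand-in for AllGather, as in the earlier sketch
    return np.concatenate(per_rank, axis=0)

degree = 2  # number of devices contributing gradients
# Per-rank sparse grads: (indices, values, dense_shape).
g0 = (np.array([0, 2]), np.ones((2, 3), np.float32), (6, 3))
g1 = (np.array([1, 2]), np.full((2, 3), 2.0, np.float32), (6, 3))

indices = all_gather([g0[0], g1[0]])
dout = all_gather([g0[1], g1[1]])
dout = dout * np.float32(1.0 / degree)  # scale the values, not the indices
grad = (indices, dout, g0[2])           # keep the original dense shape; the old
                                        # code read dout[2], i.e. a row of the
                                        # gathered values, which was wrong
print(grad[0], grad[1][:, 0], grad[2])
```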
@@ -110,7 +114,7 @@ def _tensors_allreduce_with_sparse(allreduce_filter, grad):
     if allreduce_filter:
         indices = _all_gather(grad[0])
         dout = _all_gather(grad[1])
-        grad = (indices, dout, dout[2])
+        grad = (indices, dout, grad[2])
     return grad
@@ -131,6 +135,20 @@ def _tensors_get_datatype(grad):
     return F.dtype(grad)
 
 
+@_get_datatype.register("Tuple")
+def _tensors_get_datatype_with_sparse(grad):
+    """
+    Acquire gradient datatype.
+
+    Args:
+        grad (Tuple): The gradient tensor before operation.
+
+    Returns:
+        mstype, the datatype of gradient.
+    """
+    return F.dtype(grad[1])
+
+
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
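`_get_datatype` is a `MultitypeFuncGraph`, so the new `"Tuple"` registration is selected whenever a gradient arrives as a sparse tuple instead of a plain tensor, and `F.dtype(grad[1])` reports the dtype of the values, since the integer indices and the shape carry no precision information. A plain-Python sketch of the same dispatch-by-type idea (not the MindSpore machinery itself):

```python
# Illustration only: dispatch on the gradient's representation, mirroring the
# "Tensor" and "Tuple" registrations of _get_datatype.
import numpy as np

def get_datatype(grad):
    if isinstance(grad, tuple):  # sparse: (indices, values, dense_shape)
        return grad[1].dtype     # precision lives in the values
    return grad.dtype            # dense: the tensor itself

dense = np.zeros((3,), np.float16)
sparse = (np.array([0, 2]), np.zeros((2, 3), np.float16), (6, 3))
print(get_datatype(dense), get_datatype(sparse))  # float16 float16
```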
@@ -149,6 +167,22 @@ def _tensors_cast_datatype(datatype, grad):
     return F.cast(grad, datatype)
 
 
+@_cast_datatype.register("TypeType", "Tuple")
+def _tensors_cast_datatype_with_sparse(datatype, grad):
+    """
+    Cast gradient to datatype.
+
+    Args:
+        datatype (mstype): the destination datatype of gradient.
+        grad (Tuple): The gradient tensor before operation.
+
+    Returns:
+        Tuple, the gradient tuple after operation.
+    """
+    dout = F.cast(grad[1], datatype)
+    return (grad[0], dout, grad[2])
+
+
 class DistributedGradReducer(Cell):
     """
     A distributed optimizer.
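The matching `"Tuple"` registration for `_cast_datatype` casts only the values component: indices must stay integer, and the dense shape is untouched. Sketched in plain Python (illustration only):

```python
# Illustration only: cast a sparse gradient's values while preserving the
# integer indices and the dense shape.
import numpy as np

def cast_datatype(datatype, grad):
    if isinstance(grad, tuple):
        return (grad[0], grad[1].astype(datatype), grad[2])
    return grad.astype(datatype)

sparse = (np.array([0, 2]), np.zeros((2, 3), np.float16), (6, 3))
up = cast_datatype(np.float32, sparse)
print(up[0].dtype, up[1].dtype)  # indices stay integer; values become float32
```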
@@ -224,7 +258,7 @@ class DistributedGradReducer(Cell):
     def __init__(self, parameters, mean=True, degree=None):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
-        self.hyper_map = C.HyperMap()
+        self.map_ = C.Map()
         self.mul = P.Mul()
         if degree is None:
             self.degree = get_group_size()
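The switch from `HyperMap` to `Map` is what lets sparse tuples reach the new registrations at all: `Map` applies the multitype function once per top-level element of the gradient tuple, so a sparse `(indices, values, shape)` gradient arrives at the `"Tuple"` branch intact, whereas `HyperMap` recurses into nested structures and would only ever hand the leaf tensors to the function. A plain-Python analogy of the two behaviors (not the MindSpore composites themselves):

```python
# Illustration only: elementwise Map vs. recursive HyperMap over mixed grads.
def map_(fn, seq):
    """Apply fn once per top-level element, like C.Map."""
    return tuple(fn(x) for x in seq)

def hyper_map(fn, obj):
    """Recurse into tuples first, like C.HyperMap, reaching only the leaves."""
    if isinstance(obj, tuple):
        return tuple(hyper_map(fn, o) for o in obj)
    return fn(obj)

def describe(g):
    return "sparse" if isinstance(g, tuple) else "dense"

grads = ("tensor_a", ("indices", "values", "shape"), "tensor_b")
print(map_(describe, grads))
# ('dense', 'sparse', 'dense')  -- the sparse tuple is seen as one element
print(hyper_map(describe, grads))
# ('dense', ('dense', 'dense', 'dense'), 'dense')  -- recursion splits it apart
```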
@@ -234,19 +268,27 @@ class DistributedGradReducer(Cell):
             self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce()
+        _init_optimizer_communication()
 
     def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
+        """
+        In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
+        result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
+        and cast back after the operation.
+
+        Args:
+            grads (Union[Tensor, tuple[Tensor]]): The gradient tensor or tuple before operation.
+
+        Returns:
+            new_grads (Union[Tensor, tuple[Tensor]]), the gradient tensor or tuple after operation.
+        """
+        datatypes = self.map_(F.partial(_get_datatype), grads)
+        grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
         if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
         else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
+            new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
+        new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
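For completeness, the reducer is typically consumed by a training wrapper like the one in the class's own docstring example. A condensed, unverified sketch under r0.5-era APIs follows; the `TrainingWrapper` name, the exact import path, and `network`/`optimizer` are illustrative, and distributed communication is assumed to be initialized beforehand via `mindspore.communication.init()`:

```python
# Sketch: wiring DistributedGradReducer into a data-parallel training step.
# Mirrors the usage pattern shown in the class docstring; not a verified script.
import mindspore.nn as nn
from mindspore.ops import composite as C, functional as F
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer  # assumed path

class TrainingWrapper(nn.Cell):
    def __init__(self, network, optimizer, mean=True):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True)
        # One reducer for all trainable parameters; degree is the group size.
        self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean,
                                                   get_group_size())

    def construct(self, data, label):
        loss = self.network(data, label)
        grads = self.grad(self.network, self.weights)(data, label)
        # AllReduce dense grads, AllGather sparse (indices, values, shape) ones.
        grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
```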