Commit 170d74c8 authored by Ziyan

use one allreduce in grad reducer

Parent b5d8dad4
@@ -27,32 +27,26 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 def _init_allreduce_operators(length):
     """ initialize allreduce communication operators"""
-    is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
-    split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
-    if is_parallel_optimizer and split_indices:
-        group = 1
-        fusion = ()
-        for i in range(length):
-            fusion = fusion + (group,)
-            if split_indices[group - 1] <= i + 1:
-                if group >= len(split_indices):
-                    continue
-                group = group + 1
-        index = tuple(range(1, length + 1))
-    else:
-        fusion = (1,) * length
-        index = (0,) * length
-    opt_list = ()
+    group = 1
+    fusion = ()
     for i in range(length):
-        opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
-        opt.add_prim_attr('fusion', fusion[i])
-        opt.add_prim_attr('index', index[i])
-        opt_list = opt_list + (opt,)
-    return opt_list
+        fusion = fusion + (group,)
+        if split_indices[group - 1] <= i + 1:
+            if group >= len(split_indices):
+                continue
+            group = group + 1
+    index = tuple(range(1, length + 1))
+    op_list = ()
+    for i in range(length):
+        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        op.add_prim_attr('fusion', fusion[i])
+        op.add_prim_attr('index', index[i])
+        op_list = op_list + (op,)
+    return op_list
 
 
-@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool")
-def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter):
+@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
+def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
     """
     Apply allreduce on gradient.
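
As a quick sanity check of the grouping logic kept in _init_allreduce_operators above, the loop can be rerun on its own outside MindSpore. The sketch below is illustrative only: split_indices is passed in explicitly for self-containment, and the example values are made up rather than taken from the commit.

def build_fusion(length, split_indices):
    # mirrors the grouping loop in _init_allreduce_operators
    group = 1
    fusion = ()
    for i in range(length):
        fusion = fusion + (group,)
        if split_indices[group - 1] <= i + 1:
            if group >= len(split_indices):
                continue
            group = group + 1
    index = tuple(range(1, length + 1))
    return fusion, index

# 5 parameters with split points after the 2nd and 4th gradient:
# gradients 1-2 land in fusion group 1, gradients 3-5 in group 2.
print(build_fusion(5, (2, 4)))   # ((1, 1, 2, 2, 2), (1, 2, 3, 4, 5))
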
@@ -60,9 +54,10 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc
         degree (int): The mean coefficient.
         mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
         allgather (Primitive): The communication operator for sparse gradients.
+        allreduce (Primitive): The communication operator for gradients.
         allreduce_filter (bool): When it is true, allreduce would apply.
         grad (Tensor): The gradient tensor before operation.
-        allreduce (Primitive): The communication operator for gradients.
+        ps_parameter(Bool): Use parameter server or not.
 
     Returns:
         Tensor, the gradient tensor after operation.
@@ -78,8 +73,8 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr
     return grad
 
 
-@reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function")
-def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
+@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
+def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
     """
     Apply allgather on gradient instead of allreduce for sparse feature.
     Allgather is a communication operation used for distributed deep learning.
@@ -88,9 +83,9 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr
         degree (int): The mean coefficient.
         mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
         allgather (Primitive): The communication operator for sparse gradients.
-        allreduce_filter (bool): When it is true, allgather would apply.
-        grad (IndexedSlices): The gradient before operation.
         allreduce (Primitive): The communication operator for gradients.
+        allreduce_filter (bool): When it is true, allgather would apply.
+        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
 
     Returns:
         IndexedSlices, the gradient after operation.
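
The pairs of registered overloads above are selected by argument type when the graph is compiled: dense Tensor gradients take the allreduce path, while sparse IndexedSlices gradients take the allgather path. As a rough plain-Python analogue of that type-based dispatch (not the MultitypeFuncGraph mechanism itself, which keys on the types of all arguments), functools.singledispatch can illustrate the idea; the class names below are invented stand-ins.

from functools import singledispatch

class DenseGrad:
    """Stand-in for a dense gradient tensor."""
    def __init__(self, value):
        self.value = value

class SparseGrad:
    """Stand-in for an IndexedSlices (indices + values) gradient."""
    def __init__(self, indices, values):
        self.indices = indices
        self.values = values

@singledispatch
def reduce_grad(grad):
    raise TypeError(f"no overload registered for {type(grad).__name__}")

@reduce_grad.register
def _(grad: DenseGrad):
    # dense path: a real implementation would AllReduce grad.value
    return f"allreduce({grad.value})"

@reduce_grad.register
def _(grad: SparseGrad):
    # sparse path: a real implementation would AllGather indices and values
    return f"allgather({grad.indices}, {grad.values})"

print(reduce_grad(DenseGrad(1.0)))                   # allreduce(1.0)
print(reduce_grad(SparseGrad([0, 3], [0.5, 0.25])))  # allgather([0, 3], [0.5, 0.25])
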
@@ -256,7 +251,14 @@ class DistributedGradReducer(Cell):
         self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        self.opt_list = _init_allreduce_operators(len(parameters))
+        is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
+        split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
+        if is_parallel_optimizer and split_indices:
+            self.split_fusion = True
+            self.op_list = _init_allreduce_operators(len(parameters))
+        else:
+            self.split_fusion = False
+            self.allreduce = AllReduce().add_prim_attr('fusion', 1)
         self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
         ps_filter = lambda x: x.is_param_ps
         self.ps_parameters = tuple(ps_filter(x) for x in parameters)
@@ -275,8 +277,11 @@ class DistributedGradReducer(Cell):
         """
         datatypes = self.map_(F.partial(_get_datatype), grads)
         grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
-        new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                             self.allreduce_filter, grads, self.opt_list, self.ps_parameters)
+        if self.split_fusion:
+            new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
+                                 self.op_list, self.allreduce_filter, grads, self.ps_parameters)
+        else:
+            new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
+                                           self.allreduce), self.allreduce_filter, grads, self.ps_parameters)
         new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
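
To make the two construct-time paths above concrete, here is a small plain-Python sketch that loosely mirrors how the mapped reduce function is driven either by a per-gradient operator list (split-fusion path) or by one shared operator bound up front (single-allreduce path). Everything in it is a stand-in: fake_allreduce does no communication, and reduce_one only imitates the filter-then-average shape of the dense overload.

from functools import partial

def fake_allreduce(grad):
    # stand-in for AllReduce: a real call would sum `grad` across devices
    return grad

def reduce_one(degree, mean, op, allreduce_filter, grad):
    # loosely mirrors the dense overload: skip filtered grads, reduce, then average
    if not allreduce_filter:
        return grad
    grad = op(grad)
    return grad / degree if mean else grad

grads = [4.0, 8.0, 12.0]
allreduce_filter = [True, True, False]

# split-fusion path: one operator per gradient (all the same stub here)
op_list = [fake_allreduce] * len(grads)
split = list(map(partial(reduce_one, 4, True), op_list, allreduce_filter, grads))

# single-allreduce path: one shared operator reused for every gradient
fused = list(map(partial(reduce_one, 4, True, fake_allreduce), allreduce_filter, grads))

print(split)   # [1.0, 2.0, 12.0]
print(fused)   # [1.0, 2.0, 12.0]
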