diff --git a/mindspore/nn/wrap/grad_reducer.py b/mindspore/nn/wrap/grad_reducer.py index 6989a1a4a48caab1326d9fd64e36988ed1776033..930cabf478b90f717b86e4cac1df6ef9af53ccec 100644 --- a/mindspore/nn/wrap/grad_reducer.py +++ b/mindspore/nn/wrap/grad_reducer.py @@ -27,32 +27,26 @@ reduce_opt = C.MultitypeFuncGraph("reduce_opt") def _init_allreduce_operators(length): """ initialize allreduce communication operators""" - is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") - split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() - if is_parallel_optimizer and split_indices: - group = 1 - fusion = () - for i in range(length): - fusion = fusion + (group,) - if split_indices[group - 1] <= i + 1: - if group >= len(split_indices): - continue - group = group + 1 - index = tuple(range(1, length + 1)) - else: - fusion = (1,) * length - index = (0,) * length - opt_list = () + group = 1 + fusion = () for i in range(length): - opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) - opt.add_prim_attr('fusion', fusion[i]) - opt.add_prim_attr('index', index[i]) - opt_list = opt_list + (opt,) - return opt_list + fusion = fusion + (group,) + if split_indices[group - 1] <= i + 1: + if group >= len(split_indices): + continue + group = group + 1 + index = tuple(range(1, length + 1)) + op_list = () + for i in range(length): + op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) + op.add_prim_attr('fusion', fusion[i]) + op.add_prim_attr('index', index[i]) + op_list = op_list + (op,) + return op_list -@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function", "Bool") -def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce, ps_parameter): +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") +def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): """ Apply allreduce on gradient. @@ -60,9 +54,10 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc degree (int): The mean coefficient. mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. allgather (Primitive): The communication operator for sparse gradients. + allreduce (Primitive): The communication operator for gradients. allreduce_filter (bool): When it is true, allreduce would apply. grad (Tensor): The gradient tensor before operation. - allreduce (Primitive): The communication operator for gradients. + ps_parameter(Bool): Use parameter server or not. Returns: Tensor, the gradient tensor after operation. @@ -78,8 +73,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc return grad -@reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function") -def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce): +@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") +def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): """ Apply allgather on gradient instead of allreduce for sparse feature. Allgather is a communication operation used for distributed deep learning. @@ -88,9 +83,9 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr degree (int): The mean coefficient. mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. allgather (Primitive): The communication operator for sparse gradients. - allreduce_filter (bool): When it is true, allgather would apply. - grad (IndexedSlices): The gradient before operation. allreduce (Primitive): The communication operator for gradients. + allreduce_filter (bool): When it is true, allgather would apply. + grad (tuple): The indices, gradient tensor and tensor_shape before operation. Returns: IndexedSlices, the gradient after operation. @@ -256,7 +251,14 @@ class DistributedGradReducer(Cell): self.degree = degree self.mean = mean self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) - self.opt_list = _init_allreduce_operators(len(parameters)) + is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer") + split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() + if is_parallel_optimizer and split_indices: + self.split_fusion = True + self.op_list = _init_allreduce_operators(len(parameters)) + else: + self.split_fusion = False + self.allreduce = AllReduce().add_prim_attr('fusion', 1) self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP) ps_filter = lambda x: x.is_param_ps self.ps_parameters = tuple(ps_filter(x) for x in parameters) @@ -275,8 +277,11 @@ class DistributedGradReducer(Cell): """ datatypes = self.map_(F.partial(_get_datatype), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) - new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), - self.allreduce_filter, grads, self.opt_list, self.ps_parameters) - + if self.split_fusion: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), + self.opt_list, self.allreduce_filter, grads, self.ps_parameters) + else: + new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, + self.allreduce), self.allreduce_filter, grads, self.ps_parameters) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) return new_grad