diff --git a/model_zoo/official/cv/resnet_thor/README.md b/model_zoo/official/cv/resnet_thor/README.md index d01b2fd5c9a00334f4ff536cb01627e1b97d05b8..54285e5d4921b79c2751900dd662162cb9681e26 100644 --- a/model_zoo/official/cv/resnet_thor/README.md +++ b/model_zoo/official/cv/resnet_thor/README.md @@ -217,7 +217,7 @@ Inference result will be stored in the example path, whose folder name is "eval" ``` Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. ``` - result: {'top_5_accuracy': 0.9286771766965429, 'top_1_accuracy': 0.7613036171574904} ckpt=train_parallel/resnet-36_5004.ckpt + result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt ``` ## Model Description diff --git a/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py b/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py index dbc7b3796a6dcffaa2e4aa20f039d34d231ab8b5..aef0766571b032fb93416e81681364efb73a8dbf 100644 --- a/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py +++ b/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py @@ -12,149 +12,109 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""grad_reducer_thor""" -import mindspore.common.dtype as mstype -from mindspore.communication.management import GlobalComm, get_group_size +"""grad reducer cell for distributed training""" from mindspore.nn.cell import Cell +from mindspore.communication.management import GlobalComm, get_group_size from mindspore.ops import functional as F, composite as C, operations as P -from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp +from mindspore.ops.operations.comm_ops import AllReduce +import mindspore.common.dtype as mstype reduce_opt = C.MultitypeFuncGraph("reduce_opt") -_all_reduce_A = AllReduce() +def _init_allreduce_operators(length, split_indices): + """ initialize allreduce communication operators""" + indices = split_indices[0] + fusion = split_indices[1] + op_list = () + j = 0 + for i in range(length): + if j <= len(indices)-1: + temp = indices[j] + else: + temp = length + if i >= temp: + j = j + 1 + fusion = fusion + 1 + op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) + op.add_prim_attr('fusion', fusion) + op_list = op_list + (op,) + return op_list + + +@reduce_opt.register("Function", "Number", "Function", "Tensor") +def _tensors_allreduce_mean(mul, degree, allreduce, parameters): + """ + Apply allreduce on parameters. -def _init_optimizer_allreduce(group): - global _all_reduce_A - _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) - _all_reduce_A.add_prim_attr('fusion', group) - + Args: + mul(Primitive): The mul operator for parameters. + degree (int): The mean coefficient. + allreduce (Primitive): The communication operator for parameters. + parameters (Tensor): The parameters before operation. -@reduce_opt.register("Function", "Number", "Tensor") -def _tensors_allreduce_mean(mul, degree, grad): - degree = F.scalar_cast(degree, F.dtype(grad)) - grad = _all_reduce_A(grad) + Returns: + Tensor, the parameters after operation. + """ + degree = F.scalar_cast(degree, F.dtype(parameters)) + parameters = allreduce(parameters) cast_op = P.Cast() - return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) - - -@reduce_opt.register("Bool", "Tensor") -def _tensors_allreduce(allreduce_filter, grad): - if allreduce_filter: - return _all_reduce_A(grad) - return grad + return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters))) _get_datatype = C.MultitypeFuncGraph("_get_datatype") @_get_datatype.register("Tensor") -def _tensors_get_datatype(grad): +def _tensors_get_datatype(parameters): """ - Acquire gradient datatype. + Acquire parameters datatype. Args: - grad (Tensor): The gradient tensor before operation. + parameters (Tensor): The parameters before operation. Returns: - mstype, the datatype of gradient. + mstype, the datatype of parameters. """ - return F.dtype(grad) + return F.dtype(parameters) _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") @_cast_datatype.register("TypeType", "Tensor") -def _tensors_cast_datatype(datatype, grad): +def _tensors_cast_datatype(datatype, parameters): """ - Cast gradient to datatype. + Cast parameters to datatype. Args: - datatype (mstype): the destination datatype of gradient. - grad (Tensor): The gradient tensor before operation. + datatype (mstype): the destination datatype of parameters. + parameters (Tensor): The parameters before operation. Returns: - Tensor, the gradient tensor after operation. + Tensor, the parameters after operation. """ - return F.cast(grad, datatype) + return F.cast(parameters, datatype) class DistributedGradReducerThor(Cell): """ A distributed optimizer. - Constructs a gradient reducer Cell, which applies communication and average operations on - single-process gradient values. + Constructs a parameters reducer Cell, which applies communication and average operations on + single-process parameters values. Args: - parameters (list): the parameters to be updated. - mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False. + parameter_length (int): length of the parameters to be updated. + split_indices(tuple): parameter split indices. + mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False. degree (int): The mean coefficient. Usually it equals to device number. Default: None. Raises: ValueError: If degree is not a int or less than 0. - - Examples: - >>> from mindspore.communication import init, get_group_size - >>> from mindspore.ops import composite as C - >>> from mindspore.ops import operations as P - >>> from mindspore.ops import functional as F - >>> from mindspore import context - >>> from mindspore import nn - >>> from mindspore import ParameterTuple - >>> from mindspore.context import ParallelMode - >>> - >>> device_id = int(os.environ["DEVICE_ID"]) - >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, - >>> device_id=int(device_id), enable_hccl=True) - >>> init() - >>> context.reset_auto_parallel_context() - >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) - >>> - >>> - >>> class TrainingWrapper(nn.Cell): - >>> def __init__(self, network, optimizer, sens=1.0): - >>> super(TrainingWrapper, self).__init__(auto_prefix=False) - >>> self.network = network - >>> self.network.add_flags(defer_inline=True) - >>> self.weights = ParameterTuple(network.trainable_params()) - >>> self.optimizer = optimizer - >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True) - >>> self.sens = sens - >>> self.reducer_flag = False - >>> self.grad_reducer = None - >>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - >>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL, - >>> ParallelMode.HYBRID_PARALLEL]: - >>> self.reducer_flag = True - >>> if self.reducer_flag: - >>> mean = context.get_auto_parallel_context("gradients_mean") - >>> if mean.get_device_num_is_set(): - >>> degree = context.get_auto_parallel_context("device_num") - >>> else: - >>> degree = get_group_size() - >>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) - >>> - >>> def construct(self, *args): - >>> weights = self.weights - >>> loss = self.network(*args) - >>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - >>> grads = self.grad(self.network, weights)(*args, sens) - >>> if self.reducer_flag: - >>> # apply grad reducer on grads - >>> grads = self.grad_reducer(grads) - >>> return F.depend(loss, self.optimizer(grads)) - >>> - >>> network = Net() - >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> train_cell = TrainingWrapper(network, optimizer) - >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) - >>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) - >>> grads = train_cell(inputs, label) """ - def __init__(self, parameters, group, mean=True, degree=None): + def __init__(self, parameter_length, split_indices, mean=True, degree=None): super(DistributedGradReducerThor, self).__init__(auto_prefix=False) self.hyper_map = C.HyperMap() self.mul = P.Mul() @@ -165,16 +125,11 @@ class DistributedGradReducerThor(Cell): raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") self.degree = degree self.mean = mean - self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) - _init_optimizer_allreduce(group) - - def construct(self, grads): - # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the - # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce, - # and cast back after the operation. - datatypes = self.hyper_map(F.partial(_get_datatype), grads) - grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) - - new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) - new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) - return new_grad + self.op_list = _init_allreduce_operators(parameter_length, split_indices) + + def construct(self, parameters): + datatypes = self.hyper_map(F.partial(_get_datatype), parameters) + parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters) + new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters) + new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters) + return new_parameters diff --git a/model_zoo/official/cv/resnet_thor/src/thor.py b/model_zoo/official/cv/resnet_thor/src/thor.py index 44b6930684ba52dab9455117981dde1f4d2a2acf..d5118386d3a2cf210e62c345ea8a25761f1fdabf 100644 --- a/model_zoo/official/cv/resnet_thor/src/thor.py +++ b/model_zoo/official/cv/resnet_thor/src/thor.py @@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype from mindspore._checkparam import check_bool from mindspore._checkparam import Validator as validator from mindspore.nn.optim.optimizer import Optimizer -from mindspore.parallel._utils import _get_device_num, _get_gradients_mean +from mindspore.parallel._utils import _get_device_num, _get_mirror_mean from src.grad_reducer_thor import DistributedGradReducerThor _momentum_opt = C.MultitypeFuncGraph("momentum_opt") @@ -85,10 +85,12 @@ class THOR_GPU(Optimizer): self.assign = P.Assign() self.mul = P.Mul() - mean = _get_gradients_mean() + mean = _get_mirror_mean() degree = _get_device_num() - self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree) - self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree) + + parameter_length = len(self.feature_map) + self.grad_reducer_thorA = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree) + self.grad_reducer_thorG = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree) self.weight_decay = weight_decay self.decay_flags = tuple(decay_filter(x) for x in self.parameters) self.update_gradient = P.UpdateThorGradient(split_dim=128) @@ -191,12 +193,13 @@ class THOR(Optimizer): 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0] - mean = _get_gradients_mean() + mean = _get_mirror_mean() degree = _get_device_num() - self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree) - self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree) - self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree) - self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree) + parameter_length = len(self.feature_map) + self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree) + self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree) + self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree) + self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree) self.matrix_A_inv = () self.matrix_G_inv = () self.matrix_max_inv = () diff --git a/model_zoo/official/cv/resnet_thor/train.py b/model_zoo/official/cv/resnet_thor/train.py index 5d8ce2f38f8dc3886af225b0cfbf7e02c50b98c9..29d4e58e32169fb3ae5953340621751be2a13392 100644 --- a/model_zoo/official/cv/resnet_thor/train.py +++ b/model_zoo/official/cv/resnet_thor/train.py @@ -95,11 +95,7 @@ if __name__ == '__main__': context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") + auto_parallel_context().set_all_reduce_fusion_split_indices([107]) init() # GPU target else: diff --git a/tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py b/tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py index 02c37b1127d8642626691e43a1440d6ab418fe90..aef0766571b032fb93416e81681364efb73a8dbf 100644 --- a/tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py +++ b/tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py @@ -12,150 +12,109 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""grad_reducer_thor""" -import mindspore.common.dtype as mstype -from mindspore.communication.management import GlobalComm, get_group_size +"""grad reducer cell for distributed training""" from mindspore.nn.cell import Cell +from mindspore.communication.management import GlobalComm, get_group_size from mindspore.ops import functional as F, composite as C, operations as P -from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp +from mindspore.ops.operations.comm_ops import AllReduce +import mindspore.common.dtype as mstype reduce_opt = C.MultitypeFuncGraph("reduce_opt") -_all_reduce_A = AllReduce() +def _init_allreduce_operators(length, split_indices): + """ initialize allreduce communication operators""" + indices = split_indices[0] + fusion = split_indices[1] + op_list = () + j = 0 + for i in range(length): + if j <= len(indices)-1: + temp = indices[j] + else: + temp = length + if i >= temp: + j = j + 1 + fusion = fusion + 1 + op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) + op.add_prim_attr('fusion', fusion) + op_list = op_list + (op,) + return op_list + + +@reduce_opt.register("Function", "Number", "Function", "Tensor") +def _tensors_allreduce_mean(mul, degree, allreduce, parameters): + """ + Apply allreduce on parameters. -def _init_optimizer_allreduce(group): - global _all_reduce_A - _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) - _all_reduce_A.add_prim_attr('fusion', group) - + Args: + mul(Primitive): The mul operator for parameters. + degree (int): The mean coefficient. + allreduce (Primitive): The communication operator for parameters. + parameters (Tensor): The parameters before operation. -@reduce_opt.register("Function", "Number", "Tensor") -def _tensors_allreduce_mean(mul, degree, grad): - degree = F.scalar_cast(degree, F.dtype(grad)) - grad = _all_reduce_A(grad) + Returns: + Tensor, the parameters after operation. + """ + degree = F.scalar_cast(degree, F.dtype(parameters)) + parameters = allreduce(parameters) cast_op = P.Cast() - return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) - - -@reduce_opt.register("Bool", "Tensor") -def _tensors_allreduce(allreduce_filter, grad): - if allreduce_filter: - return _all_reduce_A(grad) - return grad + return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters))) _get_datatype = C.MultitypeFuncGraph("_get_datatype") @_get_datatype.register("Tensor") -def _tensors_get_datatype(grad): +def _tensors_get_datatype(parameters): """ - Acquire gradient datatype. + Acquire parameters datatype. Args: - grad (Tensor): The gradient tensor before operation. + parameters (Tensor): The parameters before operation. Returns: - mstype, the datatype of gradient. + mstype, the datatype of parameters. """ - return F.dtype(grad) + return F.dtype(parameters) _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") @_cast_datatype.register("TypeType", "Tensor") -def _tensors_cast_datatype(datatype, grad): +def _tensors_cast_datatype(datatype, parameters): """ - Cast gradient to datatype. + Cast parameters to datatype. Args: - datatype (mstype): the destination datatype of gradient. - grad (Tensor): The gradient tensor before operation. + datatype (mstype): the destination datatype of parameters. + parameters (Tensor): The parameters before operation. Returns: - Tensor, the gradient tensor after operation. + Tensor, the parameters after operation. """ - return F.cast(grad, datatype) + return F.cast(parameters, datatype) class DistributedGradReducerThor(Cell): """ A distributed optimizer. - Constructs a gradient reducer Cell, which applies communication and average operations on - single-process gradient values. + Constructs a parameters reducer Cell, which applies communication and average operations on + single-process parameters values. Args: - parameters (list): the parameters to be updated. - group (int): the different group to allreduce. - mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False. + parameter_length (int): length of the parameters to be updated. + split_indices(tuple): parameter split indices. + mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False. degree (int): The mean coefficient. Usually it equals to device number. Default: None. Raises: ValueError: If degree is not a int or less than 0. - - Examples: - >>> from mindspore.communication import init, get_group_size - >>> from mindspore.ops import composite as C - >>> from mindspore.ops import operations as P - >>> from mindspore.ops import functional as F - >>> from mindspore import context - >>> from mindspore import nn - >>> from mindspore import ParameterTuple - >>> from mindspore.context import ParallelMode - >>> - >>> device_id = int(os.environ["DEVICE_ID"]) - >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, - >>> device_id=int(device_id), enable_hccl=True) - >>> init() - >>> context.reset_auto_parallel_context() - >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) - >>> - >>> - >>> class TrainingWrapper(nn.Cell): - >>> def __init__(self, network, optimizer, sens=1.0): - >>> super(TrainingWrapper, self).__init__(auto_prefix=False) - >>> self.network = network - >>> self.network.add_flags(defer_inline=True) - >>> self.weights = ParameterTuple(network.trainable_params()) - >>> self.optimizer = optimizer - >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True) - >>> self.sens = sens - >>> self.reducer_flag = False - >>> self.grad_reducer = None - >>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - >>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL, - >>> ParallelMode.HYBRID_PARALLEL]: - >>> self.reducer_flag = True - >>> if self.reducer_flag: - >>> mean = context.get_auto_parallel_context("gradients_mean") - >>> if mean.get_device_num_is_set(): - >>> degree = context.get_auto_parallel_context("device_num") - >>> else: - >>> degree = get_group_size() - >>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) - >>> - >>> def construct(self, *args): - >>> weights = self.weights - >>> loss = self.network(*args) - >>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) - >>> grads = self.grad(self.network, weights)(*args, sens) - >>> if self.reducer_flag: - >>> # apply grad reducer on grads - >>> grads = self.grad_reducer(grads) - >>> return F.depend(loss, self.optimizer(grads)) - >>> - >>> network = Net() - >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> train_cell = TrainingWrapper(network, optimizer) - >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) - >>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) - >>> grads = train_cell(inputs, label) """ - def __init__(self, parameters, group, mean=True, degree=None): + def __init__(self, parameter_length, split_indices, mean=True, degree=None): super(DistributedGradReducerThor, self).__init__(auto_prefix=False) self.hyper_map = C.HyperMap() self.mul = P.Mul() @@ -166,20 +125,11 @@ class DistributedGradReducerThor(Cell): raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") self.degree = degree self.mean = mean - self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) - _init_optimizer_allreduce(group) - - def construct(self, grads): - # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the - # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce, - # and cast back after the operation. - datatypes = self.hyper_map(F.partial(_get_datatype), grads) - grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) - - if self.mean: - new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) - else: - new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads) - - new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) - return new_grad + self.op_list = _init_allreduce_operators(parameter_length, split_indices) + + def construct(self, parameters): + datatypes = self.hyper_map(F.partial(_get_datatype), parameters) + parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters) + new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters) + new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters) + return new_parameters diff --git a/tests/st/networks/models/resnet50/src_thor/thor.py b/tests/st/networks/models/resnet50/src_thor/thor.py index b5b1faa1d50a1aaf1343bda0ab77aadd81ce8700..4f293bd4bade5dd547bd15cddee9ac4ad2281f1c 100644 --- a/tests/st/networks/models/resnet50/src_thor/thor.py +++ b/tests/st/networks/models/resnet50/src_thor/thor.py @@ -89,10 +89,11 @@ class THOR(Optimizer): 1.0] mean = _get_gradients_mean() degree = _get_device_num() - self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree) - self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree) - self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree) - self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree) + parameter_length = len(self.feature_map) + self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree) + self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree) + self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree) + self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree) self.matrix_A_inv = () self.matrix_G_inv = () self.matrix_max_inv = () diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index 28ed8b54893197b37684265f25097c08f3170d22..3efdf78310ca62d8946ec71d63374f301c73b154 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -241,11 +241,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): if enable_hccl: context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") + auto_parallel_context().set_all_reduce_fusion_split_indices([107]) init() # network