diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto old mode 100644 new mode 100755 index 9f3af174f607792eb416a6648cc1ff76818c2ecd..914e27d6f1f5e689c7c97b96d1875f6ba676eb00 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; } message ShardingConfig { optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; + optional bool hybrid_dp = 2 [ default = false ]; + optional int32 sharding_group_size = 3 [ default = 8 ]; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index cf6ab514b0bfe6a7dd031acb189412fd088f5bfa..03b36262a4fb1e095eb17fa57bf27b5c9f3cf74c 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -71,7 +71,11 @@ class FP16Utils(object): return inserted_op_num @staticmethod - def prune_fp16(block, shard, reduced_grads_to_param, nrings): + def prune_fp16(block, shard, reduced_grads_to_param, ring_id): + """ + 1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard + 2. revise amp inifine grad checking for sharding + """ # remove cast for idx, op in reversed(list(enumerate(block.ops))): if not FP16Utils.is_fp32_cast_op(block, op): @@ -79,9 +83,9 @@ class FP16Utils(object): output_name = op.desc.output_arg_names()[0] param_name = output_name.strip("@GRAD") if param_name not in shard.global_params: - raise ValueError("Input 'X' of check_finite_and_unscale must" - "be grads, but {} is not a grad".format( - input_name)) + raise ValueError("Output 'X' of cast_op must be a grad of" + "model param, but {} is not a grad".format( + output_name)) if output_name in reduced_grads_to_param: continue if shard.has_param(param_name): @@ -137,10 +141,12 @@ class FP16Utils(object): type='c_allreduce_max', inputs={'X': inf_var_fp32}, outputs={'Out': inf_var_fp32}, - attrs={'ring_id': 0, + attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Optimize}) - comm_op_num = insert_sync_comm_ops( - block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32]) + + comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, + ring_id, [inf_var_fp32]) + block._insert_op_without_sync( update_loss_scaling_op_idx + 3 + comm_op_num, type='cast', diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index afa46f43fc0fe3ba10560b8fcad8504d370cf88b..c6aee792fcf745a6ec51b3c4d1945415bfd9324f 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -16,14 +16,19 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole class GradientClipHelper(object): - def __init__(self): - pass + def __init__(self, sharding_ring_id): + self.sharding_ring_id = sharding_ring_id def _is_gradient_clip_op(self, op): return op.desc.has_attr("op_namescope") \ and op.desc.attr("op_namescope").startswith("/gradient_clip") def prune_gradient_clip(self, block, shard): + """ + prune gradient_clip related ops for params that not belong to cur shard + prune: square, reduce_sum, elementwise_mul + keep: sum, sqrt, elementwise_max, elementwise_div + """ deperated_vars = set() deperate_op_idx = set() for idx, op in enumerate(block.ops): @@ -75,8 +80,10 @@ class GradientClipHelper(object): type='c_allreduce_sum', inputs={'X': sum_res}, outputs={'Out': sum_res}, - attrs={'ring_id': 0, - OP_ROLE_KEY: OpRole.Optimize}) + attrs={ + 'ring_id': self.sharding_ring_id, + OP_ROLE_KEY: OpRole.Optimize + }) block._insert_op_without_sync( idx + 1, type='c_sync_calc_stream', diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py index 7348e5f6d1445abfd603c7f8033302ca5b276844..70753b59ccc318a25661e084bd305d7d76b0e2a6 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py @@ -43,6 +43,7 @@ class ProgramDeps(object): return None def _build_deps(self, ): + for var_name in self._start_vars: self._var_to_use_op[var_name] = [] self._var_to_generate_op[var_name] = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py index 27c63fc406fcbfacb47ee2d33156ba1f3dda03ec..92e36e0ec1fff352cbb88eaea7024200414c4389 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py @@ -124,6 +124,14 @@ class Shard(object): return True return False + def filter_grads(self, grads): + grads_in_shard = [] + for grad in grads: + param = grad.split("@")[0] + if self.has_param(param): + grads_in_shard.append(grad) + return grads_in_shard + class ProgramSegment(object): def __init__(self, block): diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index b5c34f87cdf22534054fb5c8499734c59b061b18..ad1cd4f60826bbf434294114d1982cb4beb3f00a 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -78,52 +78,137 @@ def check_broadcast(block): return -def check_allreduce_sum(block): +def check_allreduce_sum(block, shard, dp_ring_id=-1): """ - if a Var is allreduced, the op order should be: - - 0: op that generate Var - - 1: sync_calc - - 2: allreduce_sum op - - 3: sync_comm - - 4: op that use Var + the op order should be: + grad: + - 0: op that generate Var + - 1: sync_calc + - 2: allreduce_sum_sharding + - 3: sync_comm + - 4: allreuce_sum_dp (dp_grads) + - 5: sync_comm (dp_grads) + - 6: op that use Var (dp_grads & sum) """ - var_status = {} - for op in block.ops: + vars_status = {} + dp_grads_status = {} + idx_last_grad_allreduce = -1 + idx_amp_allreduce = -1 + idx_gradient_clip_allreduce = -1 + for idx, op in enumerate(block.ops): if op.type == "c_allreduce_sum": + ring_id = op.desc.attr("ring_id") var_name = op.desc.input_arg_names()[0] - var_status[var_name] = -1 + param = var_name.split("@")[0] + + assert 'sum' in var_name or ("@GRAD" in var_name) + if 'sum' in var_name or (not shard.has_param(param)): + vars_status[var_name] = -1 + else: + dp_grads_status[var_name] = -1 + + if ring_id != 0: + assert shard.has_param(param) + assert ring_id == dp_ring_id + + if "sum" in var_name: + idx_amp_allreduce = idx + elif "@GRAD": + idx_last_grad_allreduce = idx + + if op.type == "c_allreduce_max": + idx_gradient_clip_allreduce = idx for op in block.ops: if op.type == "c_sync_calc_stream": - for var_name in var_status: - if var_name in var_status and var_status[var_name] == 0: - var_status[var_name] = 1 + for var_name in vars_status: + if var_name in vars_status and vars_status[var_name] == 0: + vars_status[var_name] = 1 + for var_name in dp_grads_status: + if var_name in dp_grads_status and dp_grads_status[ + var_name] == 0: + dp_grads_status[var_name] = 1 + elif op.type == "c_allreduce_sum": var_name = op.desc.input_arg_names()[0] - if var_status[var_name] == -1: - raise ValueError("{} is not generated, but you are" - "trying to all-reduce it".format(var_name)) - if var_status[var_name] == 0: - raise ValueError("There should be a sync_calc op " - "after generate Var: {} and before the" - "c_allreduce_sum op".format(var_name)) - assert (var_status[var_name] == 1) - var_status[var_name] = 2 + ring_id = op.desc.attr("ring_id") + if ring_id == 0: + if var_name in vars_status: + _status = vars_status[var_name] + else: + _status = dp_grads_status[var_name] + if _status == -1: + raise ValueError("{} is not generated, but you are" + "trying to all-reduce it".format(var_name)) + if _status == 0: + raise ValueError("There should be a sync_calc op " + "after generate Var: {} and before the" + "c_allreduce_sum op".format(var_name)) + assert (_status == 1) + if var_name in vars_status: + vars_status[var_name] = 2 + else: + dp_grads_status[var_name] = 2 + else: + assert ring_id == dp_ring_id + param = var_name.split("@")[0] + assert shard.has_param(param) + assert dp_grads_status[var_name] == 3 + dp_grads_status[var_name] = 4 + elif op.type == "c_sync_comm_stream": - for var_name in op.desc.input_arg_names(): - if var_name in var_status and var_status[var_name] == 2: - var_status[var_name] = 3 + var_name = op.desc.input_arg_names()[0] + ring_id = op.desc.attr("ring_id") + if ring_id == 0: + for var_name in op.desc.input_arg_names(): + if var_name in vars_status: + assert vars_status[var_name] == 2 + vars_status[var_name] = 3 + elif var_name in dp_grads_status: + assert dp_grads_status[var_name] == 2 + dp_grads_status[var_name] = 3 + else: + for var_name in op.desc.input_arg_names(): + param = var_name.split("@")[0] + assert ring_id == dp_ring_id + assert shard.has_param(param) + assert dp_grads_status[var_name] == 4 + dp_grads_status[var_name] = 5 else: for input_name in op.desc.input_arg_names(): - if input_name in var_status: - if var_status[input_name] != 3: + if input_name in vars_status: + if vars_status[input_name] != 3: raise ValueError("There should be a sync_comm op " "after allreduce the Var: {}".format( - var_name)) + input_name)) + if input_name in dp_grads_status: + if dp_ring_id == -1: + if dp_grads_status[input_name] != 3: + raise ValueError("There should be a sync_comm op " + "after allreduce the Var: {}". + format(input_name)) + else: + if dp_grads_status[input_name] != 5: + raise ValueError( + "The grad in shard should be allreduce and sync" + "twice before usage {}".format(input_name)) + for output_name in op.desc.output_arg_names(): - if output_name in var_status and \ - var_status[output_name] == -1: - var_status[output_name] = 0 + if output_name in vars_status and \ + vars_status[output_name] == -1: + vars_status[output_name] = 0 + if output_name in dp_grads_status and \ + dp_grads_status[output_name] == -1: + dp_grads_status[output_name] = 0 + + # check sharding with amp + if idx_amp_allreduce != -1: + assert idx_amp_allreduce > idx_last_grad_allreduce + + # check sharding with gradient_clip_by_global_norm + if idx_gradient_clip_allreduce != -1: + assert idx_gradient_clip_allreduce > idx_last_grad_allreduce + return @@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars): return -def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars): +def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars): """ - _insert_sync_comm_ops + insert sync_comm_op for single var """ op_role = get_valid_op_role(block, insert_idx) - for i in range(nrings): - block._insert_op_without_sync( - insert_idx, - type='c_sync_comm_stream', - inputs={'X': comm_dep_vars}, - outputs={'Out': comm_dep_vars}, - attrs={'ring_id': i, - OP_ROLE_KEY: op_role}) - return nrings + block._insert_op_without_sync( + insert_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: op_role}) + return 1 + + +def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars): + """ + insert sync_comm_op for vars + """ + op_role = get_valid_op_role(block, insert_idx) + block._insert_op_without_sync( + insert_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': int(ring_id), + OP_ROLE_KEY: op_role}) + return 1 def insert_fill_constant_ops(block, insert_idx, fill_constant_vars): @@ -210,13 +309,11 @@ def insert_cast_ops(block, insert_idx, cast_ops): return -def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars): +def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): """ _add_allreduce_ops """ - ring_id = -1 for var in allreduce_vars: - ring_id = (ring_id + 1) % nrings block._insert_op_without_sync( insert_idx, type='c_allreduce_sum', @@ -224,17 +321,16 @@ def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars): outputs={'Out': var}, attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward}) + return -def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root): +def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ _add_broadcast_ops """ - ring_id = -1 op_role = get_valid_op_role(block, insert_idx) for broadcast_name, root_device in broadcast2root: - ring_id = (ring_id + 1) % nrings block._insert_op_without_sync( insert_idx, type='c_broadcast', @@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root): 'root': root_device, OP_ROLE_KEY: op_role }) + return diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index a449821f8c21227bd1c22c860202408616597d29..a7f704361d31af5c1535259c62f13ea0cc3d0c3b 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper impor from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.utils import * - +import logging from functools import reduce __all__ = ["ShardingOptimizer"] @@ -37,6 +37,8 @@ class ShardingOptimizer(MetaOptimizerBase): self.meta_optimizers_white_list = [ "RecomputeOptimizer", "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self._main_program = None @@ -69,9 +71,14 @@ class ShardingOptimizer(MetaOptimizerBase): startup_program=None, parameter_list=None, no_grad_set=None): - self._nrings = self.user_defined_strategy.nccl_comm_num + # TODO: (JZ-LIANG) support multiple comm in future + # self._nrings = self.user_defined_strategy.nccl_comm_num + self._nrings_sharding = 1 + self._nrings_dp = 1 self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[ "fuse_broadcast_MB"] + self.hybrid_dp = self.user_defined_strategy.sharding_configs[ + "hybrid_dp"] if self.inner_opt is None: raise ValueError( @@ -108,28 +115,38 @@ class ShardingOptimizer(MetaOptimizerBase): # check op dependecy check_broadcast(main_block) - check_allreduce_sum(main_block) + check_allreduce_sum(main_block, self._shard, self.dp_ring_id) self._wait() return optimize_ops, params_grads def _set_up(self, params_grads): # step 1: initialize nccl - worker_idx = self.role_maker._worker_index() - endpoints = self.role_maker._get_trainer_endpoints() - current_endpoint = endpoints[worker_idx] + self.global_word_size = self.role_maker._worker_num() + self.global_rank = self.role_maker._worker_index() + self.endpoints = self.role_maker._get_trainer_endpoints() + self.current_endpoint = self.endpoints[self.global_rank] self._collective_helper = CollectiveHelper(self.role_maker, - self._nrings) - for ring_id in range(self._nrings): + self._nrings_sharding) + # config sharding & dp groups + self._init_comm() + # sharding + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.sharding_group_endpoints, self.sharding_rank, + self.sharding_ring_id, True) + # dp + if self.hybrid_dp: self._collective_helper._init_communicator( - self._startup_program, current_endpoint, endpoints, worker_idx, - ring_id, None) + self._startup_program, self.current_endpoint, + self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) + startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() # step 2: split params self._params = set([x[0].name for x in params_grads]) - self._shard.setup(params_grads, worker_idx, - self.role_maker._worker_num()) + self._shard.setup(params_grads, self.sharding_rank, + self.sharding_group_size) # step 3: get broadcast vars self._broadcast_vars = self._shard.find_broadcast_params( @@ -208,12 +225,18 @@ class ShardingOptimizer(MetaOptimizerBase): """ calculate deps from allredce op to optimize op, remove ops and vars not needed in this worker + + 1. prune regularization (weight decay) + 2. prune cast_fp32_to_fp16; update amp_infine_checking + 3. prune gradient_clip related; update global_norm_sum + 4. prune optimizer op + param + gradient + """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, self._shard) FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - self._nrings) - gradientclip_helper = GradientClipHelper() + self.sharding_ring_id) + gradientclip_helper = GradientClipHelper(self.sharding_ring_id) gradientclip_helper.prune_gradient_clip(block, self._shard) # build prog deps @@ -226,6 +249,7 @@ class ShardingOptimizer(MetaOptimizerBase): output_name = output_names[0] reduced_grads.append(output_name) + # prune optimizer state and param pruned_opti_vars = [] for var_name in list(block.vars.keys()): if self._shard.is_opti_var(var_name) and \ @@ -273,6 +297,8 @@ class ShardingOptimizer(MetaOptimizerBase): op.desc.set_input('Input', reversed_input_vars) op.desc.set_output('Out', reversed_output_vars) else: + # if all outputs of this op are in _should_removed_var + # _should_removed_var: opt state not cur shard if program_deps.should_remove_op(idx): program_deps.remove_op(idx) @@ -283,16 +309,22 @@ class ShardingOptimizer(MetaOptimizerBase): """ _add_broadcast_allreduce """ - ring_id = -1 if len(self._segments) < 1: return - + # sharding if self._segments[-1]._allreduce_vars: + shard_allredue_vars = self._shard.filter_grads(self._segments[-1] + ._allreduce_vars) + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) + insert_allreduce_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self._nrings, + self.sharding_ring_id, self._segments[-1]._allreduce_vars) insert_allreduce_ops(block, self._segments[-1]._end_idx, - self._nrings, + self.sharding_ring_id, self._segments[-1]._allreduce_vars) for idx, segment in reversed(list(enumerate(self._segments))): @@ -331,13 +363,21 @@ class ShardingOptimizer(MetaOptimizerBase): segment, 0) # step2: add Sync ops - comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] - if len(comm_dep_vars) > 0: - insert_sync_comm_ops( - block, - segment._end_idx, - self._nrings, - comm_dep_vars, ) + shard_allredue_vars = self._shard.filter_grads(allreduce_vars) + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id, + shard_allredue_vars) + + broad_cast_vars = [x[0] for x in broadcast_vars] + if len(broad_cast_vars) > 0: + insert_sync_comm_ops(block, segment._end_idx, + self.sharding_ring_id, broad_cast_vars) + else: + comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] + if len(comm_dep_vars) > 0: + insert_sync_comm_ops(block, segment._end_idx, + self.sharding_ring_id, comm_dep_vars) + calc_dep_vars = fill_constant_vars + [ k for k, v in cast_ops.items() ] + self._segments[idx]._allreduce_vars @@ -354,21 +394,27 @@ class ShardingOptimizer(MetaOptimizerBase): insert_cast_ops(block, segment._end_idx, cast_ops) # step5: add broadcast ops - insert_broadcast_ops(block, segment._start_idx, self._nrings, - broadcast_vars) - + insert_broadcast_ops(block, segment._start_idx, + self.sharding_ring_id, broadcast_vars) # step6: add all_reduce ops - insert_allreduce_ops(block, segment._start_idx, self._nrings, - allreduce_vars) + # dp + if self.hybrid_dp and len(shard_allredue_vars) >= 1: + insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, + shard_allredue_vars) + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) + # sharding + insert_allreduce_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) block._sync_with_cpp() if self._segments[0]._broadcast_vars: - insert_sync_comm_ops( - block, self._segments[0]._start_idx, self._nrings, - [x[0] for x in self._segments[0]._broadcast_vars]) + broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars] + insert_sync_comm_ops(block, self._segments[0]._start_idx, + self.sharding_ring_id, broadcast_vars) insert_broadcast_ops(block, self._segments[0]._start_idx, - self._nrings, + self.sharding_ring_id, self._segments[0]._broadcast_vars) fill_constant_vars = [] @@ -409,3 +455,60 @@ class ShardingOptimizer(MetaOptimizerBase): continue block._remove_var(var_name, sync=False) block._sync_with_cpp() + + def _init_comm(self): + + if self.hybrid_dp: + self.sharding_group_size = self.user_defined_strategy.sharding_configs[ + "sharding_group_size"] + self.sharding_ring_id = 0 + self.sharding_rank = self.global_rank % self.sharding_group_size + + self.dp_group_size = self.global_word_size // self.sharding_group_size + self.dp_rank = self.global_rank // self.sharding_group_size + self.dp_ring_id = self.sharding_rank + 1 + + self.sharding_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx // self.sharding_group_size) == self.dp_rank + ] + self.dp_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx % self.sharding_group_size) == self.sharding_rank + ] + assert self.global_word_size > self.sharding_group_size, \ + "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) + assert self.global_word_size % self.sharding_group_size == 0, \ + "global_word_size: {} should be divisible to the sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) + assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \ + "global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format( + self.global_word_size, + self.sharding_group_size, + self.dp_group_size) + + logging.info("Using Sharing&DP mode !") + else: + self.sharding_ring_id = 0 + self.sharding_rank = self.global_rank + self.sharding_group_size = self.role_maker._worker_num() + self.sharding_group_endpoints = self.endpoints + self.dp_ring_id = -1 + self.dp_rank = -1 + self.dp_group_size = None + self.dp_group_endpoints = None + + logging.info("Using Sharing alone mode !") + + logging.info("global word size: {}".format(self.global_word_size)) + logging.info("global rank: {}".format(self.global_rank)) + logging.info("sharding group_size: {}".format(self.sharding_group_size)) + logging.info("sharding rank: {}".format(self.sharding_rank)) + logging.info("dp group size: {}".format(self.dp_group_size)) + logging.info("dp rank: {}".format(self.dp_rank)) + logging.info("current endpoint: {}".format(self.current_endpoint)) + logging.info("sharding group endpoints: {}".format( + self.sharding_group_endpoints)) + logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) + logging.info("global word endpoints: {}".format(self.endpoints)) + + return