Unverified · Commit ab04bf01 · Authored by JZ-LIANG · Committed by GitHub

[2.0/cherrypick] cherry-pick Sharding PR:29518 (#29593)

* Sharding add hybrid-dp feature

* update sharding in distributed_strategy

* update sharding unittest

* revise code format for sharding
Parent: d82b0300
@@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }

 message ShardingConfig {
   optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
+  optional bool hybrid_dp = 2 [ default = false ];
+  optional int32 sharding_group_size = 3 [ default = 8 ];
 }

 message AMPConfig {
...
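For context, the two new fields are driven from user code through the sharding strategy configs. A minimal sketch, assuming the usual `fleet.DistributedStrategy` interface (the surrounding training setup is not part of this diff):

```python
# Minimal sketch: enabling sharding with hybrid data parallelism.
# The dict keys mirror the ShardingConfig fields added above; the model,
# optimizer, and fleet.init() calls are assumed and not shown here.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "fuse_broadcast_MB": 32.0,
    "hybrid_dp": True,           # new in this PR
    "sharding_group_size": 8,    # new in this PR
}
```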
@@ -71,7 +71,11 @@ class FP16Utils(object):
         return inserted_op_num

     @staticmethod
-    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
+    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
+        """
+        1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
+        2. revise amp infinite-grad checking for sharding
+        """
         # remove cast
         for idx, op in reversed(list(enumerate(block.ops))):
             if not FP16Utils.is_fp32_cast_op(block, op):
@@ -79,9 +83,9 @@ class FP16Utils(object):
             output_name = op.desc.output_arg_names()[0]
             param_name = output_name.strip("@GRAD")
             if param_name not in shard.global_params:
-                raise ValueError("Input 'X' of check_finite_and_unscale must"
-                                 "be grads, but {} is not a grad".format(
-                                     input_name))
+                raise ValueError("Output 'X' of cast_op must be a grad of "
+                                 "model param, but {} is not a grad".format(
+                                     output_name))
             if output_name in reduced_grads_to_param:
                 continue
             if shard.has_param(param_name):
@@ -137,10 +141,12 @@ class FP16Utils(object):
             type='c_allreduce_max',
             inputs={'X': inf_var_fp32},
             outputs={'Out': inf_var_fp32},
-            attrs={'ring_id': 0,
+            attrs={'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Optimize})
-        comm_op_num = insert_sync_comm_ops(
-            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
+        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
+                                          ring_id, [inf_var_fp32])
         block._insert_op_without_sync(
             update_loss_scaling_op_idx + 3 + comm_op_num,
             type='cast',
...
@@ -16,14 +16,19 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

 class GradientClipHelper(object):
-    def __init__(self):
-        pass
+    def __init__(self, sharding_ring_id):
+        self.sharding_ring_id = sharding_ring_id

     def _is_gradient_clip_op(self, op):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/gradient_clip")

     def prune_gradient_clip(self, block, shard):
+        """
+        prune gradient_clip related ops for params that do not belong to the current shard
+        prune: square, reduce_sum, elementwise_mul
+        keep: sum, sqrt, elementwise_max, elementwise_div
+        """
         deperated_vars = set()
         deperate_op_idx = set()
         for idx, op in enumerate(block.ops):
@@ -75,8 +80,10 @@ class GradientClipHelper(object):
                 type='c_allreduce_sum',
                 inputs={'X': sum_res},
                 outputs={'Out': sum_res},
-                attrs={'ring_id': 0,
-                       OP_ROLE_KEY: OpRole.Optimize})
+                attrs={
+                    'ring_id': self.sharding_ring_id,
+                    OP_ROLE_KEY: OpRole.Optimize
+                })
             block._insert_op_without_sync(
                 idx + 1,
                 type='c_sync_calc_stream',
...
@@ -43,6 +43,7 @@ class ProgramDeps(object):
         return None

     def _build_deps(self, ):
         for var_name in self._start_vars:
             self._var_to_use_op[var_name] = []
             self._var_to_generate_op[var_name] = []
...
@@ -124,6 +124,14 @@ class Shard(object):
             return True
         return False

+    def filter_grads(self, grads):
+        grads_in_shard = []
+        for grad in grads:
+            param = grad.split("@")[0]
+            if self.has_param(param):
+                grads_in_shard.append(grad)
+        return grads_in_shard
+

 class ProgramSegment(object):
     def __init__(self, block):
...
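A small, self-contained illustration of what the new `Shard.filter_grads` helper above does; the gradient names are hypothetical, only the `@GRAD` naming convention is taken from the diff:

```python
# filter_grads keeps only the grads whose owning param belongs to this shard.
grads = ["fc_0.w_0@GRAD", "fc_0.b_0@GRAD", "fc_1.w_0@GRAD"]
params_in_shard = {"fc_0.w_0", "fc_0.b_0"}  # stand-in for shard.has_param()

grads_in_shard = [g for g in grads if g.split("@")[0] in params_in_shard]
print(grads_in_shard)  # ['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD']
```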
@@ -78,52 +78,137 @@ def check_broadcast(block):
     return


-def check_allreduce_sum(block):
+def check_allreduce_sum(block, shard, dp_ring_id=-1):
     """
-    if a Var is allreduced, the op order should be:
+    the op order should be:
+    grad:
         - 0: op that generate Var
         - 1: sync_calc
-        - 2: allreduce_sum op
+        - 2: allreduce_sum_sharding
         - 3: sync_comm
-        - 4: op that use Var
+        - 4: allreduce_sum_dp (dp_grads)
+        - 5: sync_comm (dp_grads)
+        - 6: op that use Var (dp_grads & sum)
     """
-    var_status = {}
-    for op in block.ops:
+    vars_status = {}
+    dp_grads_status = {}
+    idx_last_grad_allreduce = -1
+    idx_amp_allreduce = -1
+    idx_gradient_clip_allreduce = -1
+    for idx, op in enumerate(block.ops):
         if op.type == "c_allreduce_sum":
+            ring_id = op.desc.attr("ring_id")
             var_name = op.desc.input_arg_names()[0]
-            var_status[var_name] = -1
+            param = var_name.split("@")[0]
+            assert 'sum' in var_name or ("@GRAD" in var_name)
+            if 'sum' in var_name or (not shard.has_param(param)):
+                vars_status[var_name] = -1
+            else:
+                dp_grads_status[var_name] = -1
+            if ring_id != 0:
+                assert shard.has_param(param)
+                assert ring_id == dp_ring_id
+            if "sum" in var_name:
+                idx_amp_allreduce = idx
+            elif "@GRAD" in var_name:
+                idx_last_grad_allreduce = idx
+        if op.type == "c_allreduce_max":
+            idx_gradient_clip_allreduce = idx
     for op in block.ops:
         if op.type == "c_sync_calc_stream":
-            for var_name in var_status:
-                if var_name in var_status and var_status[var_name] == 0:
-                    var_status[var_name] = 1
+            for var_name in vars_status:
+                if var_name in vars_status and vars_status[var_name] == 0:
+                    vars_status[var_name] = 1
+            for var_name in dp_grads_status:
+                if var_name in dp_grads_status and dp_grads_status[
+                        var_name] == 0:
+                    dp_grads_status[var_name] = 1
         elif op.type == "c_allreduce_sum":
             var_name = op.desc.input_arg_names()[0]
-            if var_status[var_name] == -1:
-                raise ValueError("{} is not generated, but you are"
-                                 "trying to all-reduce it".format(var_name))
-            if var_status[var_name] == 0:
-                raise ValueError("There should be a sync_calc op "
-                                 "after generate Var: {} and before the"
-                                 "c_allreduce_sum op".format(var_name))
-            assert (var_status[var_name] == 1)
-            var_status[var_name] = 2
+            ring_id = op.desc.attr("ring_id")
+            if ring_id == 0:
+                if var_name in vars_status:
+                    _status = vars_status[var_name]
+                else:
+                    _status = dp_grads_status[var_name]
+                if _status == -1:
+                    raise ValueError("{} is not generated, but you are"
+                                     "trying to all-reduce it".format(var_name))
+                if _status == 0:
+                    raise ValueError("There should be a sync_calc op "
+                                     "after generate Var: {} and before the"
+                                     "c_allreduce_sum op".format(var_name))
+                assert (_status == 1)
+                if var_name in vars_status:
+                    vars_status[var_name] = 2
+                else:
+                    dp_grads_status[var_name] = 2
+            else:
+                assert ring_id == dp_ring_id
+                param = var_name.split("@")[0]
+                assert shard.has_param(param)
+                assert dp_grads_status[var_name] == 3
+                dp_grads_status[var_name] = 4
elif op.type == "c_sync_comm_stream": elif op.type == "c_sync_comm_stream":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
for var_name in op.desc.input_arg_names(): for var_name in op.desc.input_arg_names():
if var_name in var_status and var_status[var_name] == 2: if var_name in vars_status:
var_status[var_name] = 3 assert vars_status[var_name] == 2
vars_status[var_name] = 3
elif var_name in dp_grads_status:
assert dp_grads_status[var_name] == 2
dp_grads_status[var_name] = 3
else:
for var_name in op.desc.input_arg_names():
param = var_name.split("@")[0]
assert ring_id == dp_ring_id
assert shard.has_param(param)
assert dp_grads_status[var_name] == 4
dp_grads_status[var_name] = 5
else: else:
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
if input_name in var_status: if input_name in vars_status:
if var_status[input_name] != 3: if vars_status[input_name] != 3:
raise ValueError("There should be a sync_comm op " raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format( "after allreduce the Var: {}".format(
var_name)) input_name))
if input_name in dp_grads_status:
if dp_ring_id == -1:
if dp_grads_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".
format(input_name))
else:
if dp_grads_status[input_name] != 5:
raise ValueError(
"The grad in shard should be allreduce and sync"
"twice before usage {}".format(input_name))
for output_name in op.desc.output_arg_names(): for output_name in op.desc.output_arg_names():
if output_name in var_status and \ if output_name in vars_status and \
var_status[output_name] == -1: vars_status[output_name] == -1:
var_status[output_name] = 0 vars_status[output_name] = 0
if output_name in dp_grads_status and \
dp_grads_status[output_name] == -1:
dp_grads_status[output_name] = 0
# check sharding with amp
if idx_amp_allreduce != -1:
assert idx_amp_allreduce > idx_last_grad_allreduce
# check sharding with gradient_clip_by_global_norm
if idx_gradient_clip_allreduce != -1:
assert idx_gradient_clip_allreduce > idx_last_grad_allreduce
return return
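To make the expected ordering concrete, here is a small runnable summary of the status transitions the checker enforces for a grad owned by the current shard when hybrid dp is enabled (a sketch of the `dp_grads_status` lifecycle above, not part of the diff itself):

```python
# Expected lifecycle of a dp grad under hybrid dp, as enforced by
# check_allreduce_sum; the integers are the dp_grads_status values.
EXPECTED_DP_GRAD_ORDER = [
    (0, "generated by its backward op"),
    (1, "c_sync_calc_stream"),
    (2, "c_allreduce_sum on the sharding ring (ring_id 0)"),
    (3, "c_sync_comm_stream on the sharding ring"),
    (4, "c_allreduce_sum on the dp ring (dp_ring_id)"),
    (5, "c_sync_comm_stream on the dp ring"),
]

if __name__ == "__main__":
    for status, step in EXPECTED_DP_GRAD_ORDER:
        print(status, step)
    # only after status 5 may another op consume the grad
```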
@@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
     return


-def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
+def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
     """
-    _insert_sync_comm_ops
+    insert sync_comm_op for single var
     """
     op_role = get_valid_op_role(block, insert_idx)
-    for i in range(nrings):
-        block._insert_op_without_sync(
-            insert_idx,
-            type='c_sync_comm_stream',
-            inputs={'X': comm_dep_vars},
-            outputs={'Out': comm_dep_vars},
-            attrs={'ring_id': i,
-                   OP_ROLE_KEY: op_role})
-    return nrings
+    block._insert_op_without_sync(
+        insert_idx,
+        type='c_sync_comm_stream',
+        inputs={'X': comm_dep_vars},
+        outputs={'Out': comm_dep_vars},
+        attrs={'ring_id': ring_id,
+               OP_ROLE_KEY: op_role})
+    return 1
+
+
+def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
+    """
+    insert sync_comm_op for vars
+    """
+    op_role = get_valid_op_role(block, insert_idx)
+    block._insert_op_without_sync(
+        insert_idx,
+        type='c_sync_comm_stream',
+        inputs={'X': comm_dep_vars},
+        outputs={'Out': comm_dep_vars},
+        attrs={'ring_id': int(ring_id),
+               OP_ROLE_KEY: op_role})
+    return 1


 def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
@@ -210,13 +309,11 @@ def insert_cast_ops(block, insert_idx, cast_ops):
     return


-def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
+def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
     """
     _add_allreduce_ops
     """
-    ring_id = -1
     for var in allreduce_vars:
-        ring_id = (ring_id + 1) % nrings
         block._insert_op_without_sync(
             insert_idx,
             type='c_allreduce_sum',
@@ -224,17 +321,16 @@ def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
             outputs={'Out': var},
             attrs={'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Backward})
     return


-def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
+def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
     """
     _add_broadcast_ops
     """
-    ring_id = -1
     op_role = get_valid_op_role(block, insert_idx)
     for broadcast_name, root_device in broadcast2root:
-        ring_id = (ring_id + 1) % nrings
         block._insert_op_without_sync(
             insert_idx,
             type='c_broadcast',
@@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
                 'root': root_device,
                 OP_ROLE_KEY: op_role
             })
     return
...
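A brief runnable illustration of the design change in the helpers above: the old `insert_allreduce_ops` and `insert_broadcast_ops` round-robined vars across `nrings` communication rings, while the new versions pin every inserted op to the single `ring_id` passed by the caller. The variable names below are hypothetical:

```python
# Old behaviour: vars alternate across nrings rings (here nrings = 2).
# New behaviour: every op uses the one ring chosen by the caller,
# e.g. the sharding ring.
allreduce_vars = ["g0", "g1", "g2", "g3"]

nrings = 2
old_assignment = {v: i % nrings for i, v in enumerate(allreduce_vars)}

sharding_ring_id = 0
new_assignment = {v: sharding_ring_id for v in allreduce_vars}

print(old_assignment)  # {'g0': 0, 'g1': 1, 'g2': 0, 'g3': 1}
print(new_assignment)  # {'g0': 0, 'g1': 0, 'g2': 0, 'g3': 0}
```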
@@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper impor
 from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
 from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
 from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
+import logging
 from functools import reduce

 __all__ = ["ShardingOptimizer"]
@@ -37,6 +37,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.meta_optimizers_white_list = [
             "RecomputeOptimizer",
             "AMPOptimizer",
+            "LarsOptimizer",
+            "LambOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
         self._main_program = None
@@ -69,9 +71,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
-        self._nrings = self.user_defined_strategy.nccl_comm_num
+        # TODO: (JZ-LIANG) support multiple comm in future
+        # self._nrings = self.user_defined_strategy.nccl_comm_num
+        self._nrings_sharding = 1
+        self._nrings_dp = 1
         self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
             "fuse_broadcast_MB"]
+        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
+            "hybrid_dp"]

         if self.inner_opt is None:
             raise ValueError(
@@ -108,28 +115,38 @@ class ShardingOptimizer(MetaOptimizerBase):
         # check op dependecy
         check_broadcast(main_block)
-        check_allreduce_sum(main_block)
+        check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
         self._wait()
         return optimize_ops, params_grads

     def _set_up(self, params_grads):
         # step 1: initialize nccl
-        worker_idx = self.role_maker._worker_index()
-        endpoints = self.role_maker._get_trainer_endpoints()
-        current_endpoint = endpoints[worker_idx]
+        self.global_word_size = self.role_maker._worker_num()
+        self.global_rank = self.role_maker._worker_index()
+        self.endpoints = self.role_maker._get_trainer_endpoints()
+        self.current_endpoint = self.endpoints[self.global_rank]
         self._collective_helper = CollectiveHelper(self.role_maker,
-                                                   self._nrings)
-        for ring_id in range(self._nrings):
-            self._collective_helper._init_communicator(
-                self._startup_program, current_endpoint, endpoints, worker_idx,
-                ring_id, None)
+                                                   self._nrings_sharding)
+        # config sharding & dp groups
+        self._init_comm()
+        # sharding
+        self._collective_helper._init_communicator(
+            self._startup_program, self.current_endpoint,
+            self.sharding_group_endpoints, self.sharding_rank,
+            self.sharding_ring_id, True)
+        # dp
+        if self.hybrid_dp:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)

         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()

         # step 2: split params
         self._params = set([x[0].name for x in params_grads])
-        self._shard.setup(params_grads, worker_idx,
-                          self.role_maker._worker_num())
+        self._shard.setup(params_grads, self.sharding_rank,
+                          self.sharding_group_size)

         # step 3: get broadcast vars
         self._broadcast_vars = self._shard.find_broadcast_params(
@@ -208,12 +225,18 @@ class ShardingOptimizer(MetaOptimizerBase):
         """
         calculate deps from allredce op to optimize op,
         remove ops and vars not needed in this worker
+        1. prune regularization (weight decay)
+        2. prune cast_fp32_to_fp16; update amp infinite-grad checking
+        3. prune gradient_clip related; update global_norm_sum
+        4. prune optimizer op + param + gradient
         """
         weightdecay_helper = WeightDecayHelper()
         weightdecay_helper.prune_weight_decay(block, self._shard)
         FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
-                             self._nrings)
-        gradientclip_helper = GradientClipHelper()
+                             self.sharding_ring_id)
+        gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
         gradientclip_helper.prune_gradient_clip(block, self._shard)

         # build prog deps
@@ -226,6 +249,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 output_name = output_names[0]
                 reduced_grads.append(output_name)

+        # prune optimizer state and param
         pruned_opti_vars = []
         for var_name in list(block.vars.keys()):
             if self._shard.is_opti_var(var_name) and \
@@ -273,6 +297,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                     op.desc.set_input('Input', reversed_input_vars)
                     op.desc.set_output('Out', reversed_output_vars)
                 else:
+                    # if all outputs of this op are in _should_removed_var, remove the op
+                    # (_should_removed_var: optimizer state that does not belong to the current shard)
                     if program_deps.should_remove_op(idx):
                         program_deps.remove_op(idx)
@@ -283,16 +309,22 @@ class ShardingOptimizer(MetaOptimizerBase):
         """
         _add_broadcast_allreduce
         """
-        ring_id = -1
         if len(self._segments) < 1:
             return

+        # sharding
         if self._segments[-1]._allreduce_vars:
+            shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
+                                                           ._allreduce_vars)
+            if self.hybrid_dp and len(shard_allredue_vars) >= 1:
+                insert_sync_comm_ops(block, self._segments[-1]._end_idx,
+                                     self.dp_ring_id, shard_allredue_vars)
+                insert_allreduce_ops(block, self._segments[-1]._end_idx,
+                                     self.dp_ring_id, shard_allredue_vars)
             insert_sync_comm_ops(block, self._segments[-1]._end_idx,
-                                 self._nrings,
+                                 self.sharding_ring_id,
                                  self._segments[-1]._allreduce_vars)
             insert_allreduce_ops(block, self._segments[-1]._end_idx,
-                                 self._nrings,
+                                 self.sharding_ring_id,
                                  self._segments[-1]._allreduce_vars)

         for idx, segment in reversed(list(enumerate(self._segments))):
@@ -331,13 +363,21 @@ class ShardingOptimizer(MetaOptimizerBase):
                 segment, 0)

             # step2: add Sync ops
-            comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
-            if len(comm_dep_vars) > 0:
-                insert_sync_comm_ops(
-                    block,
-                    segment._end_idx,
-                    self._nrings,
-                    comm_dep_vars, )
+            shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
+            if self.hybrid_dp and len(shard_allredue_vars) >= 1:
+                insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id,
+                                     shard_allredue_vars)
+                broad_cast_vars = [x[0] for x in broadcast_vars]
+                if len(broad_cast_vars) > 0:
+                    insert_sync_comm_ops(block, segment._end_idx,
+                                         self.sharding_ring_id, broad_cast_vars)
+            else:
+                comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
+                if len(comm_dep_vars) > 0:
+                    insert_sync_comm_ops(block, segment._end_idx,
+                                         self.sharding_ring_id, comm_dep_vars)

             calc_dep_vars = fill_constant_vars + [
                 k for k, v in cast_ops.items()
             ] + self._segments[idx]._allreduce_vars
@@ -354,21 +394,27 @@ class ShardingOptimizer(MetaOptimizerBase):
             insert_cast_ops(block, segment._end_idx, cast_ops)

             # step5: add broadcast ops
-            insert_broadcast_ops(block, segment._start_idx, self._nrings,
-                                 broadcast_vars)
+            insert_broadcast_ops(block, segment._start_idx,
+                                 self.sharding_ring_id, broadcast_vars)
             # step6: add all_reduce ops
-            insert_allreduce_ops(block, segment._start_idx, self._nrings,
-                                 allreduce_vars)
+            # dp
+            if self.hybrid_dp and len(shard_allredue_vars) >= 1:
+                insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id,
+                                     shard_allredue_vars)
+                insert_sync_comm_ops(block, segment._start_idx,
+                                     self.sharding_ring_id, allreduce_vars)
+            # sharding
+            insert_allreduce_ops(block, segment._start_idx,
+                                 self.sharding_ring_id, allreduce_vars)

             block._sync_with_cpp()

         if self._segments[0]._broadcast_vars:
-            insert_sync_comm_ops(
-                block, self._segments[0]._start_idx, self._nrings,
-                [x[0] for x in self._segments[0]._broadcast_vars])
+            broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars]
+            insert_sync_comm_ops(block, self._segments[0]._start_idx,
+                                 self.sharding_ring_id, broadcast_vars)
             insert_broadcast_ops(block, self._segments[0]._start_idx,
-                                 self._nrings,
+                                 self.sharding_ring_id,
                                  self._segments[0]._broadcast_vars)

         fill_constant_vars = []
@@ -409,3 +455,60 @@ class ShardingOptimizer(MetaOptimizerBase):
                 continue
             block._remove_var(var_name, sync=False)
         block._sync_with_cpp()
+
+    def _init_comm(self):
+        if self.hybrid_dp:
+            self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                "sharding_group_size"]
+            self.sharding_ring_id = 0
+            self.sharding_rank = self.global_rank % self.sharding_group_size
+
+            self.dp_group_size = self.global_word_size // self.sharding_group_size
+            self.dp_rank = self.global_rank // self.sharding_group_size
+            self.dp_ring_id = self.sharding_rank + 1
+
+            self.sharding_group_endpoints = [
+                ep for idx, ep in enumerate(self.endpoints)
+                if (idx // self.sharding_group_size) == self.dp_rank
+            ]
+            self.dp_group_endpoints = [
+                ep for idx, ep in enumerate(self.endpoints)
+                if (idx % self.sharding_group_size) == self.sharding_rank
+            ]
+            assert self.global_word_size > self.sharding_group_size, \
+                "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
+            assert self.global_word_size % self.sharding_group_size == 0, \
+                "global_word_size: {} should be divisible by sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
+            assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \
+                "global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format(
+                    self.global_word_size,
+                    self.sharding_group_size,
+                    self.dp_group_size)
+            logging.info("Using Sharding & DP mode!")
+        else:
+            self.sharding_ring_id = 0
+            self.sharding_rank = self.global_rank
+            self.sharding_group_size = self.role_maker._worker_num()
+            self.sharding_group_endpoints = self.endpoints
+            self.dp_ring_id = -1
+            self.dp_rank = -1
+            self.dp_group_size = None
+            self.dp_group_endpoints = None
+            logging.info("Using Sharding alone mode!")
+
+        logging.info("global world size: {}".format(self.global_word_size))
+        logging.info("global rank: {}".format(self.global_rank))
+        logging.info("sharding group_size: {}".format(self.sharding_group_size))
+        logging.info("sharding rank: {}".format(self.sharding_rank))
+        logging.info("dp group size: {}".format(self.dp_group_size))
+        logging.info("dp rank: {}".format(self.dp_rank))
+        logging.info("current endpoint: {}".format(self.current_endpoint))
+        logging.info("sharding group endpoints: {}".format(
+            self.sharding_group_endpoints))
+        logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
+        logging.info("global world endpoints: {}".format(self.endpoints))
+        return
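A worked example of the group layout `_init_comm` computes when `hybrid_dp` is on, assuming 8 trainers and `sharding_group_size = 4` (both numbers are illustrative, not taken from the diff):

```python
# Reproduces the rank arithmetic from _init_comm for a hypothetical 8-GPU job.
global_world_size = 8      # role_maker._worker_num()
sharding_group_size = 4    # sharding_configs["sharding_group_size"]
dp_group_size = global_world_size // sharding_group_size  # 2

for global_rank in range(global_world_size):
    sharding_rank = global_rank % sharding_group_size
    dp_rank = global_rank // sharding_group_size
    dp_ring_id = sharding_rank + 1  # ring 0 is reserved for the sharding group
    print(global_rank, sharding_rank, dp_rank, dp_ring_id)

# Sharding groups (ranks that share a dp_rank):   {0, 1, 2, 3} and {4, 5, 6, 7}
# DP groups (ranks that share a sharding_rank):   {0, 4}, {1, 5}, {2, 6}, {3, 7}
```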