未验证 提交 d33d468f 编写于 作者: J JZ-LIANG 提交者: GitHub

[Sharding] add hybrid-dp feature (#29518)

* Sharding add hybrid-dp feature

* update sharding in distributed_strategy

* update sharding unitest

* revise code format for sharding
上级 1e72e032
......@@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
}
message AMPConfig {
......
......@@ -71,7 +71,11 @@ class FP16Utils(object):
return inserted_op_num
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
2. revise amp inifine grad checking for sharding
"""
# remove cast
for idx, op in reversed(list(enumerate(block.ops))):
if not FP16Utils.is_fp32_cast_op(block, op):
......@@ -79,9 +83,9 @@ class FP16Utils(object):
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(
input_name))
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
output_name))
if output_name in reduced_grads_to_param:
continue
if shard.has_param(param_name):
......@@ -137,10 +141,12 @@ class FP16Utils(object):
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': 0,
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_ops(
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
type='cast',
......
......@@ -16,14 +16,19 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self):
pass
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def prune_gradient_clip(self, block, shard):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
deperated_vars = set()
deperate_op_idx = set()
for idx, op in enumerate(block.ops):
......@@ -75,8 +80,10 @@ class GradientClipHelper(object):
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
......
......@@ -43,6 +43,7 @@ class ProgramDeps(object):
return None
def _build_deps(self, ):
for var_name in self._start_vars:
self._var_to_use_op[var_name] = []
self._var_to_generate_op[var_name] = []
......
......@@ -124,6 +124,14 @@ class Shard(object):
return True
return False
def filter_grads(self, grads):
grads_in_shard = []
for grad in grads:
param = grad.split("@")[0]
if self.has_param(param):
grads_in_shard.append(grad)
return grads_in_shard
class ProgramSegment(object):
def __init__(self, block):
......
......@@ -78,52 +78,137 @@ def check_broadcast(block):
return
def check_allreduce_sum(block):
def check_allreduce_sum(block, shard, dp_ring_id=-1):
"""
if a Var is allreduced, the op order should be:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum op
- 3: sync_comm
- 4: op that use Var
the op order should be:
grad:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum_sharding
- 3: sync_comm
- 4: allreuce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum)
"""
var_status = {}
for op in block.ops:
vars_status = {}
dp_grads_status = {}
idx_last_grad_allreduce = -1
idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum":
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
var_status[var_name] = -1
param = var_name.split("@")[0]
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
if ring_id != 0:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx
for op in block.ops:
if op.type == "c_sync_calc_stream":
for var_name in var_status:
if var_name in var_status and var_status[var_name] == 0:
var_status[var_name] = 1
for var_name in vars_status:
if var_name in vars_status and vars_status[var_name] == 0:
vars_status[var_name] = 1
for var_name in dp_grads_status:
if var_name in dp_grads_status and dp_grads_status[
var_name] == 0:
dp_grads_status[var_name] = 1
elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
if var_status[var_name] == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if var_status[var_name] == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (var_status[var_name] == 1)
var_status[var_name] = 2
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
elif op.type == "c_sync_comm_stream":
for var_name in op.desc.input_arg_names():
if var_name in var_status and var_status[var_name] == 2:
var_status[var_name] = 3
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
for var_name in op.desc.input_arg_names():
if var_name in vars_status:
assert vars_status[var_name] == 2
vars_status[var_name] = 3
elif var_name in dp_grads_status:
assert dp_grads_status[var_name] == 2
dp_grads_status[var_name] = 3
else:
for var_name in op.desc.input_arg_names():
param = var_name.split("@")[0]
assert ring_id == dp_ring_id
assert shard.has_param(param)
assert dp_grads_status[var_name] == 4
dp_grads_status[var_name] = 5
else:
for input_name in op.desc.input_arg_names():
if input_name in var_status:
if var_status[input_name] != 3:
if input_name in vars_status:
if vars_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format(
var_name))
input_name))
if input_name in dp_grads_status:
if dp_ring_id == -1:
if dp_grads_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".
format(input_name))
else:
if dp_grads_status[input_name] != 5:
raise ValueError(
"The grad in shard should be allreduce and sync"
"twice before usage {}".format(input_name))
for output_name in op.desc.output_arg_names():
if output_name in var_status and \
var_status[output_name] == -1:
var_status[output_name] = 0
if output_name in vars_status and \
vars_status[output_name] == -1:
vars_status[output_name] = 0
if output_name in dp_grads_status and \
dp_grads_status[output_name] == -1:
dp_grads_status[output_name] = 0
# check sharding with amp
if idx_amp_allreduce != -1:
assert idx_amp_allreduce > idx_last_grad_allreduce
# check sharding with gradient_clip_by_global_norm
if idx_gradient_clip_allreduce != -1:
assert idx_gradient_clip_allreduce > idx_last_grad_allreduce
return
......@@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
return
def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
"""
_insert_sync_comm_ops
insert sync_comm_op for single var
"""
op_role = get_valid_op_role(block, insert_idx)
for i in range(nrings):
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: op_role})
return nrings
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: op_role})
return 1
def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': int(ring_id),
OP_ROLE_KEY: op_role})
return 1
def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
......@@ -210,13 +309,11 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return
def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
"""
_add_allreduce_ops
"""
ring_id = -1
for var in allreduce_vars:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
......@@ -224,17 +321,16 @@ def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
return
def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
"""
ring_id = -1
op_role = get_valid_op_role(block, insert_idx)
for broadcast_name, root_device in broadcast2root:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_broadcast',
......@@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
'root': root_device,
OP_ROLE_KEY: op_role
})
return
......
......@@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper impor
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
import logging
from functools import reduce
__all__ = ["ShardingOptimizer"]
......@@ -37,6 +37,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
......@@ -69,9 +71,14 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._nrings = self.user_defined_strategy.nccl_comm_num
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
self._nrings_dp = 1
self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
"fuse_broadcast_MB"]
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
if self.inner_opt is None:
raise ValueError(
......@@ -108,28 +115,38 @@ class ShardingOptimizer(MetaOptimizerBase):
# check op dependecy
check_broadcast(main_block)
check_allreduce_sum(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait()
return optimize_ops, params_grads
def _set_up(self, params_grads):
# step 1: initialize nccl
worker_idx = self.role_maker._worker_index()
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[worker_idx]
self.global_word_size = self.role_maker._worker_num()
self.global_rank = self.role_maker._worker_index()
self.endpoints = self.role_maker._get_trainer_endpoints()
self.current_endpoint = self.endpoints[self.global_rank]
self._collective_helper = CollectiveHelper(self.role_maker,
self._nrings)
for ring_id in range(self._nrings):
self._nrings_sharding)
# config sharding & dp groups
self._init_comm()
# sharding
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True)
# dp
if self.hybrid_dp:
self._collective_helper._init_communicator(
self._startup_program, current_endpoint, endpoints, worker_idx,
ring_id, None)
self._startup_program, self.current_endpoint,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
# step 2: split params
self._params = set([x[0].name for x in params_grads])
self._shard.setup(params_grads, worker_idx,
self.role_maker._worker_num())
self._shard.setup(params_grads, self.sharding_rank,
self.sharding_group_size)
# step 3: get broadcast vars
self._broadcast_vars = self._shard.find_broadcast_params(
......@@ -208,12 +225,18 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
calculate deps from allredce op to optimize op,
remove ops and vars not needed in this worker
1. prune regularization (weight decay)
2. prune cast_fp32_to_fp16; update amp_infine_checking
3. prune gradient_clip related; update global_norm_sum
4. prune optimizer op + param + gradient
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self._nrings)
gradientclip_helper = GradientClipHelper()
self.sharding_ring_id)
gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
gradientclip_helper.prune_gradient_clip(block, self._shard)
# build prog deps
......@@ -226,6 +249,7 @@ class ShardingOptimizer(MetaOptimizerBase):
output_name = output_names[0]
reduced_grads.append(output_name)
# prune optimizer state and param
pruned_opti_vars = []
for var_name in list(block.vars.keys()):
if self._shard.is_opti_var(var_name) and \
......@@ -273,6 +297,8 @@ class ShardingOptimizer(MetaOptimizerBase):
op.desc.set_input('Input', reversed_input_vars)
op.desc.set_output('Out', reversed_output_vars)
else:
# if all outputs of this op are in _should_removed_var
# _should_removed_var: opt state not cur shard
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
......@@ -283,16 +309,22 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
_add_broadcast_allreduce
"""
ring_id = -1
if len(self._segments) < 1:
return
# sharding
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars)
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self._nrings,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self._nrings,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
for idx, segment in reversed(list(enumerate(self._segments))):
......@@ -331,13 +363,21 @@ class ShardingOptimizer(MetaOptimizerBase):
segment, 0)
# step2: add Sync ops
comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(
block,
segment._end_idx,
self._nrings,
comm_dep_vars, )
shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id,
shard_allredue_vars)
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, broad_cast_vars)
else:
comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, comm_dep_vars)
calc_dep_vars = fill_constant_vars + [
k for k, v in cast_ops.items()
] + self._segments[idx]._allreduce_vars
......@@ -354,21 +394,27 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops
insert_broadcast_ops(block, segment._start_idx, self._nrings,
broadcast_vars)
insert_broadcast_ops(block, segment._start_idx,
self.sharding_ring_id, broadcast_vars)
# step6: add all_reduce ops
insert_allreduce_ops(block, segment._start_idx, self._nrings,
allreduce_vars)
# dp
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id,
shard_allredue_vars)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# sharding
insert_allreduce_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
block._sync_with_cpp()
if self._segments[0]._broadcast_vars:
insert_sync_comm_ops(
block, self._segments[0]._start_idx, self._nrings,
[x[0] for x in self._segments[0]._broadcast_vars])
broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars]
insert_sync_comm_ops(block, self._segments[0]._start_idx,
self.sharding_ring_id, broadcast_vars)
insert_broadcast_ops(block, self._segments[0]._start_idx,
self._nrings,
self.sharding_ring_id,
self._segments[0]._broadcast_vars)
fill_constant_vars = []
......@@ -409,3 +455,60 @@ class ShardingOptimizer(MetaOptimizerBase):
continue
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
def _init_comm(self):
if self.hybrid_dp:
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
"sharding_group_size"]
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank % self.sharding_group_size
self.dp_group_size = self.global_word_size // self.sharding_group_size
self.dp_rank = self.global_rank // self.sharding_group_size
self.dp_ring_id = self.sharding_rank + 1
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.dp_rank
]
self.dp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
assert self.global_word_size > self.sharding_group_size, \
"global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.global_word_size % self.sharding_group_size == 0, \
"global_word_size: {} should be divisible to the sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \
"global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format(
self.global_word_size,
self.sharding_group_size,
self.dp_group_size)
logging.info("Using Sharing&DP mode !")
else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
logging.info("Using Sharing alone mode !")
logging.info("global word size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size))
logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("dp group size: {}".format(self.dp_group_size))
logging.info("dp rank: {}".format(self.dp_rank))
logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
logging.info("global word endpoints: {}".format(self.endpoints))
return
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册