Unverified commit 69c874fd authored by JZ-LIANG, committed by GitHub

[3D-Parallel:Sharding] Optimizations for supporting ERNIE 3.0 training (#31884)

Parent 43367e4b
......@@ -29,9 +29,14 @@ message RecomputeConfig {
}
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
optional int32 sharding_degree = 3 [ default = 8 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
[ default = 'segment_broadcast_MB' ];
repeated string segment_anchors = 6;
optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
}
message AMPConfig {
......
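The ShardingConfig changes above rename `fuse_broadcast_MB` to `segment_broadcast_MB` and `sharding_group_size` to `sharding_degree`, and add `sharding_segment_strategy`, `segment_anchors`, `mp_degree` and `gradient_merge_acc_step`. A minimal sketch of how these proto fields are driven from the Python side via `fleet.DistributedStrategy`, mirroring the unit tests later in this diff; the concrete values are illustrative, not recommendations.

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# Keys mirror the updated ShardingConfig proto; values here are only examples.
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",  # or "segment_anchors"
    "segment_broadcast_MB": 32.0,     # segment size threshold used for broadcast fusion
    "sharding_degree": 8,             # number of ranks the optimizer states are sharded over
    "mp_degree": 1,                   # inner (Megatron-style) model-parallel degree
    "hybrid_dp": False,               # outer pure data-parallel group on top of sharding
    "gradient_merge_acc_step": 1,     # > 1 enables gradient merge / accumulation
}
```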
......@@ -59,6 +59,7 @@ class AMPOptimizer(MetaOptimizerBase):
is_distributed = self.role_maker._worker_num() > 1
if self.user_defined_strategy.sharding:
# FIXME(wangxi). sharding failed when split check_finite_and_unscale
# FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior that to disable is_distributed.
is_distributed = False
self.wrapped_opt._set_distributed(is_distributed)
......
......@@ -73,7 +73,7 @@ class FP16Utils(object):
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
2. revise amp infinite grad checking for sharding
"""
# remove cast
......@@ -103,6 +103,7 @@ class FP16Utils(object):
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
......@@ -111,12 +112,24 @@ class FP16Utils(object):
"be grads, but {} is not a grad".format(input_name))
if shard.has_param(param_name):
reversed_x.append(input_name)
reversed_x_paramname.append(param_name)
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)
# the grad checking should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(
set([param for param, worker_idx in shard.global_param2device.items() \
if worker_idx == shard.worker_idx]))
assert to_check_param == should_check_param, "amp \
check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
should_check_param - to_check_param,
to_check_param - should_check_param)
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_fp32 = block.create_var(
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
......@@ -128,32 +141,30 @@ class FP16Utils(object):
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_fp32},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_fp32.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
[inf_var_fp32])
# this allreduce communication should not overlap with calc
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_fp32},
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
attrs={
"in_dtype": inf_var_fp32.dtype,
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
......
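After the pruning above, `check_finite_and_unscale` only sees the gradients of parameters owned by the current shard, so each rank's found-infinite flag must be combined across the sharding group. The revised pass does this with a bool-to-int32 cast, a `c_allreduce_max` scheduled on the calc stream (replacing the earlier fp32 cast plus the extra sync ops), and a cast back. The sketch below only illustrates the reduction semantics with NumPy; it is not the program-pass code, and the function name is made up.

```python
import numpy as np

def combine_found_inf(per_rank_found_inf):
    """Semantics of the inserted cast -> c_allreduce_max -> cast sequence:
    the global found-infinite flag is the max of the per-rank int32 flags."""
    flags_int32 = np.array(per_rank_found_inf, dtype=np.int32)  # cast bool -> int32
    reduced = flags_int32.max()                                 # c_allreduce_max over the sharding ring
    return bool(reduced)                                        # cast back for update_loss_scaling

# If any rank overflowed locally, every rank must skip the parameter update.
assert combine_found_inf([False, True, False, False]) is True
assert combine_found_inf([False, False]) is False
```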
......@@ -16,14 +16,14 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def __init__(self, mp_ring_id):
self.mp_ring_id = mp_ring_id
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def prune_gradient_clip(self, block, shard):
def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
......@@ -31,6 +31,7 @@ class GradientClipHelper(object):
"""
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
......@@ -44,6 +45,8 @@ class GradientClipHelper(object):
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
elif shard.is_param(param_name):
reversed_x_paramname.append(param_name)
if deperate_op:
deperate_op_idx.add(idx)
......@@ -65,31 +68,48 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
# this allreduce should not overlap with calc and should be scheduled in calc stream
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
'ring_id': self.mp_ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})
# global norm should only be sum within each model parallelism word size when use global group
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
# the grad sum here should take the all and only param in the current shard
to_check_param = set(reversed_x_paramname)
should_check_param = set(shard.global_params).intersection(set(
[param for param, worker_idx in shard.global_param2device.items() \
if worker_idx == shard.worker_idx]))
assert to_check_param == should_check_param, "amp check_finite_and_unscale \
checking miss [{}] and got unexpected [{}]".format(
should_check_param - to_check_param,
to_check_param - should_check_param)
for var_name in deperated_vars:
block._remove_var(var_name, sync=False)
......
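Under sharding, each rank only holds the squared-norm terms of the parameters it owns, so the partial sum is allreduced over the model-parallel ring on the calc stream; when that allreduce runs over a global group that also spans a pure data-parallel dimension, every dp replica contributes an identical copy and the sum must be rescaled by `1 / pure_dp_degree`, which is what the inserted `scale` op does. A small numeric sketch of that correction under assumed group sizes:

```python
# Hypothetical layout: 2 sharding ranks x 2 pure-dp replicas sharing one global ring.
per_shard_partial_sq_norm = [4.0, 5.0]   # each sharding rank's local sum of squared grads
pure_dp_degree = 2

# An allreduce over the global group sums every rank, so each dp replica's
# identical contribution is counted pure_dp_degree times.
allreduced = sum(per_shard_partial_sq_norm) * pure_dp_degree          # 18.0

# The inserted scale op (scale = 1/pure_dp_degree, bias = 0) restores the true value.
global_sq_norm = allreduced * (1.0 / pure_dp_degree)                  # 9.0
global_norm = global_sq_norm ** 0.5                                   # 3.0
```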
......@@ -28,21 +28,24 @@ def check_broadcast(block):
if the broadcasted var has a fill_constant op, the fill_constant
op should stay forward before the broadcast op, and before a
sync_calc op. Otherwise, raise error.
should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
"""
broadcast_vars = {}
for idx, op in enumerate(block.ops):
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[var_name][
"broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[
var_name]["broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
for idx, op in enumerate(block.ops):
if op.type == "fill_constant":
......@@ -61,14 +64,15 @@ def check_broadcast(block):
last_sync_calc_op_idx = idx
continue
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
for input_name in op.desc.input_arg_names():
if input_name in broadcast_vars:
assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
......@@ -78,43 +82,48 @@ def check_broadcast(block):
return
def check_allreduce_sum(block, shard, dp_ring_id=-1):
def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
"""
the op order should be:
grad:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum_sharding
- 2: reduce_sum_sharding (allreduce --> reduce)
- 3: sync_comm
- 4: allreuce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads)
- 6: op that use Var (dp_grads & sum)
should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
"""
vars_status = {}
dp_grads_status = {}
idx_last_grad_allreduce = -1
idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum":
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
# sharding use both allreduce and reduce to sync grad
if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
if ring_id != 0:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if ring_id != sharding_ring_id:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx
......@@ -128,38 +137,41 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
if var_name in dp_grads_status and dp_grads_status[
var_name] == 0:
dp_grads_status[var_name] = 1
elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
# check sharding allreduce and reduce but skip megatron allreduce
elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == sharding_ring_id:
assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(
var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
elif op.type == "c_sync_comm_stream":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if ring_id == sharding_ring_id:
for var_name in op.desc.input_arg_names():
if var_name in vars_status:
assert vars_status[var_name] == 2
......@@ -181,6 +193,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format(
input_name))
raise ValueError(
"The reduce output grad [{}] should NOT be be used in Non-root rank.".
format(input_name))
if input_name in dp_grads_status:
if dp_ring_id == -1:
if dp_grads_status[input_name] != 3:
......@@ -325,6 +340,27 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
return
def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
"""
_add_allreduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
OP_ROLE_KEY: OpRole.Backward
})
return
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
......@@ -428,7 +464,7 @@ def comm_analyse(main_program):
count))
def add_sync_comm(program, dist_strategy):
def add_sync_comm(program, sharding_ring_id):
"""
When clone a test prog by clone from the sharding main prog,
part of the sync_comm op maybe be pruned by mistake, this function
......@@ -438,6 +474,7 @@ def add_sync_comm(program, dist_strategy):
#NOTE (liangjianzhong): only support one comm stream by now, use more than one
# comm streams will cause error. should be revise in future.
assert sharding_ring_id >= 0, "sharding_ring_id should larger than zero"
block = program.global_block()
not_sync_vars = set([])
for op in block.ops:
......@@ -448,15 +485,14 @@ def add_sync_comm(program, dist_strategy):
for input_name in op.desc.input_arg_names():
not_sync_vars.remove(input_name)
if not_sync_vars:
for nccl_id in range(dist_strategy.nccl_comm_num):
block.append_op(
type='c_sync_comm_stream',
inputs={'X': list(not_sync_vars)},
outputs={'Out': list(not_sync_vars)},
attrs={
'ring_id': nccl_id,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': list(not_sync_vars)},
outputs={'Out': list(not_sync_vars)},
attrs={
'ring_id': sharding_ring_id,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
})
return
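`add_sync_comm` now takes the sharding ring id directly instead of the whole `dist_strategy`, and appends a single `c_sync_comm_stream` op for any communication outputs whose sync op was pruned when cloning an evaluation program. A usage sketch mirroring the updated unit test; the import path and the ring id value are assumptions about how the communicators were numbered:

```python
from paddle.distributed.fleet.meta_optimizers import sharding

# train_prog is the sharded training program built by the sharding optimizer.
test_prog = train_prog.clone(for_test=True)

# Re-append the c_sync_comm_stream ops that clone() may have pruned,
# assuming the sharding communicator was created with ring_id == 1.
sharding.utils.add_sync_comm(test_prog, 1)
```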
......@@ -468,7 +504,7 @@ def save_persistables(exe, dirname, main_program, filename=None):
"""
def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
# NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
# now only Momentum and adam are compatible with sharding
checks = [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0",
......@@ -479,12 +515,18 @@ def save_persistables(exe, dirname, main_program, filename=None):
return True
return False
def is_gradient_merge_vars(var):
# NOTE(JZ-LIANG): to revise save/load logic in framework instead of write this naive rule
return var.name.endswith("@GradiantMerge")
def is_trainable(var):
return isinstance(var,
paddle.fluid.framework.Parameter) and var.trainable
def sharding_predicate(var):
return is_trainable(var) or is_opt_vars(var)
return is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars(
var)
if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0:
paddle.fluid.io.save_persistables(
......@@ -498,3 +540,42 @@ def save_persistables(exe, dirname, main_program, filename=None):
filename=None)
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
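`get_grad_device` strips a known gradient suffix (checking `.cast_fp16@GRAD` before the plain `@GRAD`, so the traversal order matters) and looks the base parameter up in `shard.global_param2device`; this is how `insert_reduce_ops` chooses the `root_id` of each `c_reduce_sum`. A small sketch of the mapping with a stand-in shard object; the parameter names and placement are invented for illustration, and `get_grad_device` is assumed to be imported from the sharding utils module:

```python
class _FakeShard(object):
    # Hypothetical placement: two parameters sharded over two workers.
    global_param2device = {"fc_0.w_0": 0, "fc_1.w_0": 1}

shard = _FakeShard()

# AMP-cast grads resolve through the longer suffix first, plain grads through '@GRAD'.
assert get_grad_device("fc_0.w_0.cast_fp16@GRAD", shard) == 0
assert get_grad_device("fc_1.w_0@GRAD", shard) == 1
# insert_reduce_ops then emits one c_reduce_sum per grad with root_id set to this device.
```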
def append_naive_sync(block, sync_var, ring_id):
# NOTE (JZ-LIANG) update this to use barrier sync for more elegent logic
# sync within global
block.append_op(
type="fill_constant",
outputs={"Out": sync_var},
attrs={
"shape": sync_var.shape,
"dtype": sync_var.dtype,
"value": int(1),
})
block.append_op(
type='c_allreduce_sum',
inputs={'X': sync_var},
outputs={'Out': sync_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
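`append_naive_sync` emits a `fill_constant` followed by a `c_allreduce_sum` on the calc stream, which serves as a crude barrier across the given ring (the in-code note already suggests switching to a real barrier op later). A usage sketch, assuming a startup program block and an int32 scratch variable created just for this handshake; the variable name and `ring_id` are illustrative:

```python
from paddle.fluid import core

block = startup_prog.global_block()   # assumes startup_prog was built earlier

# Scratch variable used purely as the allreduce payload.
sync_var = block.create_var(
    name="sharding_naive_sync_var",   # hypothetical name
    shape=[1],
    dtype=core.VarDesc.VarType.INT32,
    persistable=False)

# Every rank writes 1 and allreduce-sums it on the calc stream, so no rank
# proceeds past this point until all ranks in the ring have reached it.
append_naive_sync(block, sync_var, ring_id=0)   # ring_id 0 assumed to be the global ring
```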
......@@ -115,7 +115,7 @@ class ProgramStats(object):
updated_min_idx = min_idx
while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[
_logger.info("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0]))
updated_min_idx = idx_
......@@ -155,7 +155,7 @@ class ProgramStats(object):
sorted_checkpoints = []
for name in checkpoints_name:
if name not in self.var_op_deps:
_logger.debug(
_logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name)
elif self.var_op_deps[name]["var_as_output_ops"] == []:
......@@ -784,7 +784,6 @@ def _append_backward_ops_with_checkpoints_(
start_idx = 0
pre_segment_end_idx = -1
while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1:
break
# min_idx: checkpoint_1' s input op
......@@ -797,6 +796,9 @@ def _append_backward_ops_with_checkpoints_(
min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
start_idx += 1
......@@ -806,15 +808,15 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
_logger.info("recompute segment[{}]".format(i))
_logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
_logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
_logger.info("recompute segment[{}]".format(i))
_logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
_logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory
......@@ -825,9 +827,7 @@ def _append_backward_ops_with_checkpoints_(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
_logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory
......@@ -888,6 +888,17 @@ def _append_backward_ops_with_checkpoints_(
continue
if name not in var_name_dict:
var_name_dict[name] = name + var_suffix
# we should create the rename var in subprog, otherwise its VarType will be BOOL
ref_var = block.program.global_block().var(name)
block.create_var(
name=var_name_dict[name],
shape=ref_var.shape,
dtype=ref_var.dtype,
type=ref_var.type,
persistable=ref_var.persistable,
stop_gradient=ref_var.stop_gradient)
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory)
......
......@@ -59,7 +59,11 @@ def runtime_main():
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"sharding_degree": 2,
}
optimizer = paddle.fluid.optimizer.Momentum(
learning_rate=0.01, momentum=0.9)
......
......@@ -146,7 +146,11 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
elif name == "sharding":
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"sharding_degree": 2,
}
elif name == "recompute-offload":
strategy.recompute = True
strategy.recompute_configs = {
......
......@@ -1125,6 +1125,7 @@ class TestDistBase(unittest.TestCase):
if check_error_log:
print("outs[0]:", outs[0])
print("outs[1]:", outs[1])
return pickle.loads(outs[0]), pickle.loads(outs[1])
def _run_pipeline(self, model, envs, check_error_log, log_name):
......
......@@ -45,6 +45,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
......@@ -55,9 +56,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'momentum',
'momentum', 'momentum'
])
def test_sharding_amp_optimizer(self):
......@@ -82,6 +83,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0",
"loss_scaling_0", "num_bad_steps_0", "num_good_steps_0"
]))
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast',
......@@ -94,11 +96,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 'cast',
'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
'c_allreduce_max', 'c_sync_comm_stream', 'cast',
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_sync_comm_stream', 'cast', 'cast', 'cast',
'check_finite_and_unscale', 'cast', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'momentum', 'momentum', 'momentum'
])
......@@ -124,6 +125,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
......@@ -134,10 +136,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad',
'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream',
'momentum', 'momentum', 'momentum'
'mul_grad', 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
])
def test_sharding_amp_recompute_optimizer(self):
......@@ -167,29 +168,27 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0",
"loss_scaling_0", "num_bad_steps_0", "num_good_steps_0"
]))
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'cast', 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh',
'cast', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2',
'mean', 'elementwise_mul', 'fill_constant', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast',
'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad',
'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add',
'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad',
'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'cast', 'cast', 'cast',
'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
'c_allreduce_max', 'c_sync_comm_stream', 'cast',
'update_loss_scaling', 'momentum', 'momentum', 'momentum'
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'cast',
'cast', 'cast', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'momentum', 'momentum'
])
def test_sharding_weight_decay(self):
......@@ -227,10 +226,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale',
'sum', 'momentum', 'momentum', 'momentum'
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'scale',
'sum', 'scale', 'sum', 'scale', 'sum', 'momentum', 'momentum',
'momentum'
])
def test_sharding_gradient_clip(self):
......@@ -253,6 +252,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
......@@ -263,14 +263,12 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'square', 'reduce_sum', 'square',
'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant',
'elementwise_max', 'elementwise_div', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
'momentum'
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square',
'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'momentum', 'momentum', 'momentum'
])
def test_sharding_clone_for_test(self):
......@@ -281,7 +279,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
sharding.utils.comm_analyse(train_prog)
test_prog = train_prog.clone(for_test=True)
sharding.utils.add_sync_comm(test_prog, strategy)
# assume sharding_ring_id = 1
sharding.utils.add_sync_comm(test_prog, 1)
ops = [op.type for op in test_prog.global_block().ops]
self.assertEqual(ops, [
......@@ -293,5 +292,200 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
])
class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "3"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004"
def test_sharding_with_mp(self):
# NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"hybrid_dp": False,
"gradient_merge_acc_step": 1,
"mp_degree": 2
}
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# should has ring id for MP
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(0, created_ring_ids)
# check correctness of MP group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
def test_sharding_hybrid_dp(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 1,
"mp_degree": 1
}
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check ring id for outter dp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(2, created_ring_ids)
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
# check loss scale for sharding hybrid dp
scale_ = -1
for op in main_prog_ops:
if op.type == "scale":
scale_ = float(op.desc.attr("scale"))
self.assertEqual(scale_, 0.25)
# check program (allreudce)
ops = [op.type for op in main_prog_ops]
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
])
def test_sharding_hybrid_dp_gm(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 4,
"mp_degree": 1
}
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check ring id for outter dp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(2, created_ring_ids)
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
# check program
fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
opt_ops = [op.type for op in train_prog.blocks[2].ops]
self.assertEqual(fw_bw_ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
'elementwise_add', 'increment', 'elementwise_mod', 'equal',
'conditional_block'
])
self.assertEqual(opt_ops, [
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
'scale', 'scale', 'momentum', 'momentum', 'momentum',
'fill_constant', 'fill_constant', 'fill_constant'
])
# # check loss scale for gradient merge
scale_ = -1
for op in train_prog.blocks[2].ops:
if op.type == "scale":
scale_ = float(op.desc.attr("scale"))
self.assertEqual(scale_, 0.25)
if __name__ == "__main__":
unittest.main()