Unverified commit 69c874fd, authored by JZ-LIANG, committed by GitHub

[3D-Parallel:Sharding] Optimizations for supporting ERNIE 3.0 training (#31884)

Parent 43367e4b
@@ -29,9 +29,14 @@ message RecomputeConfig {
 }
 
 message ShardingConfig {
-  optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
+  optional float segment_broadcast_MB = 1 [ default = 32.0 ];
   optional bool hybrid_dp = 2 [ default = false ];
-  optional int32 sharding_group_size = 3 [ default = 8 ];
+  optional int32 sharding_degree = 3 [ default = 8 ];
+  optional int32 mp_degree = 4 [ default = 1 ];
+  optional string sharding_segment_strategy = 5
+      [ default = 'segment_broadcast_MB' ];
+  repeated string segment_anchors = 6;
+  optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
 }
 
 message AMPConfig {
......
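For context, a minimal usage sketch of how these new ShardingConfig fields might be set from user code through the usual fleet.DistributedStrategy front end. The field names follow the proto above; the concrete values (degrees, accumulation steps) are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch: wiring the new ShardingConfig fields from Python.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "segment_broadcast_MB": 32,      # renamed from fuse_broadcast_MB
    "sharding_degree": 8,            # renamed from sharding_group_size
    "mp_degree": 2,                  # inner model parallelism (e.g. Megatron)
    "hybrid_dp": True,               # outer pure data parallelism
    "gradient_merge_acc_step": 4,    # gradient accumulation steps
}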
@@ -59,6 +59,7 @@ class AMPOptimizer(MetaOptimizerBase):
         is_distributed = self.role_maker._worker_num() > 1
         if self.user_defined_strategy.sharding:
             # FIXME(wangxi). sharding failed when split check_finite_and_unscale
+            # FIXME(JZ-LIANG). To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior that to disable is_distributed.
             is_distributed = False
         self.wrapped_opt._set_distributed(is_distributed)
......
@@ -73,7 +73,7 @@ class FP16Utils(object):
     @staticmethod
     def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
         """
-        1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
+        1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
         2. revise amp inifine grad checking for sharding
         """
         # remove cast
@@ -103,6 +103,7 @@ class FP16Utils(object):
                 op._rename_input(inf_var_name, inf_var_name + "@sharding")
             if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                 reversed_x = []
+                reversed_x_paramname = []
                 for input_name in op.desc.input('X'):
                     param_name = input_name.strip("@GRAD")
                     if param_name not in shard.global_params:
@@ -111,12 +112,24 @@ class FP16Utils(object):
                             "be grads, but {} is not a grad".format(input_name))
                     if shard.has_param(param_name):
                         reversed_x.append(input_name)
+                        reversed_x_paramname.append(param_name)
                 op.desc.set_input('X', reversed_x)
                 op.desc.set_output('Out', reversed_x)
+                # the grad checking should take the all and only param in the current shard
+                to_check_param = set(reversed_x_paramname)
+                should_check_param = set(shard.global_params).intersection(
+                    set([param for param, worker_idx in shard.global_param2device.items() \
+                        if worker_idx == shard.worker_idx]))
+                assert to_check_param == should_check_param, "amp \
+                    check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
+                    should_check_param - to_check_param,
+                    to_check_param - should_check_param)
         if update_loss_scaling_op_idx == -1:
             return
         inf_var = block.var(inf_var_name)
-        inf_var_fp32 = block.create_var(
+        inf_var_int32 = block.create_var(
             name=inf_var_name + "@cast_int32",
             shape=inf_var.shape,
             dtype=core.VarDesc.VarType.INT32)
@@ -128,32 +141,30 @@ class FP16Utils(object):
             update_loss_scaling_op_idx,
             type='cast',
             inputs={'X': inf_var},
-            outputs={'Out': inf_var_fp32},
+            outputs={'Out': inf_var_int32},
             attrs={
                 "in_dtype": inf_var.dtype,
-                "out_dtype": inf_var_fp32.dtype,
+                "out_dtype": inf_var_int32.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
             })
-        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
-                            [inf_var_fp32])
+        # this allreduce communication should not overlap with calc
         block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx + 1,
             type='c_allreduce_max',
-            inputs={'X': inf_var_fp32},
-            outputs={'Out': inf_var_fp32},
-            attrs={'ring_id': ring_id,
-                   OP_ROLE_KEY: OpRole.Optimize})
-
-        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
-                                          ring_id, [inf_var_fp32])
+            inputs={'X': inf_var_int32},
+            outputs={'Out': inf_var_int32},
+            attrs={
+                'ring_id': ring_id,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
         block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 3 + comm_op_num,
+            update_loss_scaling_op_idx + 2,
             type='cast',
-            inputs={'X': inf_var_fp32},
+            inputs={'X': inf_var_int32},
             outputs={'Out': inf_var_sharding},
             attrs={
-                "in_dtype": inf_var_fp32.dtype,
+                "in_dtype": inf_var_int32.dtype,
                 "out_dtype": inf_var_sharding.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
             })
......
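The change above replaces the fp32 found-infinite flag with an int32 flag that is synced by a single c_allreduce_max issued on the calc stream. A minimal sketch of the intended semantics, in plain Python with the allreduce mocked by max() and illustrative names:

# Sketch: how per-rank "found infinite gradient" flags are combined across ranks.
# Each rank casts its local bool flag to int (0/1); an allreduce-max yields 1 if
# any rank overflowed, and every rank casts the result back to bool before
# update_loss_scaling runs.
def global_found_inf(local_found_inf_flags):
    # local_found_inf_flags: list of per-rank bools (stand-in for c_allreduce_max)
    as_int = [int(flag) for flag in local_found_inf_flags]
    return bool(max(as_int))

assert global_found_inf([False, False, True, False]) is True
assert global_found_inf([False, False]) is False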
@@ -16,14 +16,14 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 
 class GradientClipHelper(object):
-    def __init__(self, sharding_ring_id):
-        self.sharding_ring_id = sharding_ring_id
+    def __init__(self, mp_ring_id):
+        self.mp_ring_id = mp_ring_id
 
     def _is_gradient_clip_op(self, op):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/gradient_clip")
 
-    def prune_gradient_clip(self, block, shard):
+    def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
         """
         prune gradient_clip related ops for params that not belong to cur shard
         prune: square, reduce_sum, elementwise_mul
@@ -31,6 +31,7 @@ class GradientClipHelper(object):
         """
         deperated_vars = set()
         deperate_op_idx = set()
+        reversed_x_paramname = []
         for idx, op in enumerate(block.ops):
             if not self._is_gradient_clip_op(op):
                 continue
@@ -44,6 +45,8 @@ class GradientClipHelper(object):
                 if shard.is_param(param_name) and \
                         not shard.has_param(param_name):
                     deperate_op = True
+                elif shard.is_param(param_name):
+                    reversed_x_paramname.append(param_name)
 
             if deperate_op:
                 deperate_op_idx.add(idx)
@@ -65,31 +68,48 @@ class GradientClipHelper(object):
                 for input_name in op.desc.input_arg_names():
                     if input_name not in deperated_vars:
                         reversed_inputs.append(input_name)
                 op.desc.set_input("X", reversed_inputs)
                 assert (len(op.desc.output_arg_names()) == 1)
                 sum_res = op.desc.output_arg_names()[0]
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_sync_comm_stream',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={'ring_id': 0,
-                           OP_ROLE_KEY: OpRole.Optimize})
+
+                # this allreduce should not overlap with calc and should be scheduled in calc stream
                 block._insert_op_without_sync(
                     idx + 1,
                     type='c_allreduce_sum',
                     inputs={'X': sum_res},
                     outputs={'Out': sum_res},
                     attrs={
-                        'ring_id': self.sharding_ring_id,
-                        OP_ROLE_KEY: OpRole.Optimize
+                        'ring_id': self.mp_ring_id,
+                        'op_namescope': "/gradient_clip_model_parallelism",
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Optimize,
                     })
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_sync_calc_stream',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={OP_ROLE_KEY: OpRole.Optimize})
+
+                # global norm should only be sum within each model parallelism word size when use global group
+                if pure_dp_degree > 1:
+                    block._insert_op_without_sync(
+                        idx + 2,
+                        type='scale',
+                        inputs={'X': sum_res},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'scale': 1.0 / float(pure_dp_degree),
+                            'op_namescope': "/gradient_clip_model_parallelism",
+                            'bias': 0.0,
+                            'bias_after_scale': False,
+                            OP_ROLE_KEY: OpRole.Optimize
+                        })
+
+        # the grad sum here should take the all and only param in the current shard
+        to_check_param = set(reversed_x_paramname)
+        should_check_param = set(shard.global_params).intersection(set(
+            [param for param, worker_idx in shard.global_param2device.items() \
+            if worker_idx == shard.worker_idx]))
+        assert to_check_param == should_check_param, "amp check_finite_and_unscale \
+            checking miss [{}] and got unexpected [{}]".format(
+            should_check_param - to_check_param,
+            to_check_param - should_check_param)
 
         for var_name in deperated_vars:
             block._remove_var(var_name, sync=False)
......
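To make the intent of the group allreduce plus the 1/pure_dp_degree rescale concrete, here is a hedged numerical sketch in plain Python (names are invented for illustration): each rank contributes the sum of squares of the gradients it owns, the allreduce adds the partial sums, and when the clip ops are replicated across a pure data-parallel group of size dp_degree the summed result is scaled back down so the squared norm is not counted dp_degree times.

import math

# Sketch: global gradient norm under sharding plus pure data parallelism.
# local_sq_sums holds each rank's sum of squared gradients for the parameters
# it owns; ranks in the same dp replica contribute identical copies, so after
# summing over the whole group we divide by dp_degree once.
def global_grad_norm(local_sq_sums, dp_degree=1):
    summed = sum(local_sq_sums)           # stands in for c_allreduce_sum
    if dp_degree > 1:
        summed *= 1.0 / float(dp_degree)  # the inserted 'scale' op
    return math.sqrt(summed)

# two sharding ranks, replicated over dp_degree=2, i.e. four contributions
print(global_grad_norm([9.0, 16.0, 9.0, 16.0], dp_degree=2))  # 5.0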
@@ -28,21 +28,24 @@ def check_broadcast(block):
     if the broadcasted var has a fill_constant op, the fill_constant
     op should stay forward before the broadcast op, and before a
     sync_calc op. Otherwise, raise error.
+
+    should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
     """
     broadcast_vars = {}
     for idx, op in enumerate(block.ops):
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if var_name in broadcast_vars:
-                    raise ValueError("var_name areadly exist: {}"
-                                     "the old pos is {}, the new pos is {}".
-                                     format(var_name, broadcast_vars[var_name][
-                                         "broadcast_pos"], idx))
-                broadcast_vars[var_name] = {
-                    "fill_constant_pos": -1,
-                    "broadcast_pos": idx,
-                }
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if var_name in broadcast_vars:
+                        raise ValueError("var_name areadly exist: {}"
+                                         "the old pos is {}, the new pos is {}".
+                                         format(var_name, broadcast_vars[
+                                             var_name]["broadcast_pos"], idx))
+                    broadcast_vars[var_name] = {
+                        "fill_constant_pos": -1,
+                        "broadcast_pos": idx,
+                    }
 
     for idx, op in enumerate(block.ops):
         if op.type == "fill_constant":
@@ -61,14 +64,15 @@ def check_broadcast(block):
             last_sync_calc_op_idx = idx
             continue
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if broadcast_vars[var_name]["fill_constant_pos"] != -1:
-                    assert (last_sync_calc_op_idx != -1)
-                    assert (broadcast_vars[var_name]["fill_constant_pos"] <
-                            last_sync_calc_op_idx)
-                    assert (last_sync_calc_op_idx < idx)
-                continue
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if broadcast_vars[var_name]["fill_constant_pos"] != -1:
+                        assert (last_sync_calc_op_idx != -1)
+                        assert (broadcast_vars[var_name]["fill_constant_pos"] <
+                                last_sync_calc_op_idx)
+                        assert (last_sync_calc_op_idx < idx)
+                    continue
         for input_name in op.desc.input_arg_names():
             if input_name in broadcast_vars:
                 assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
@@ -78,43 +82,48 @@ def check_broadcast(block):
     return
 
 
-def check_allreduce_sum(block, shard, dp_ring_id=-1):
+def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     """
     the op order should be:
         grad:
             - 0: op that generate Var
             - 1: sync_calc
-            - 2: allreduce_sum_sharding
+            - 2: reduce_sum_sharding (allreduce --> reduce)
             - 3: sync_comm
             - 4: allreuce_sum_dp (dp_grads)
             - 5: sync_comm (dp_grads)
             - 6: op that use Var (dp_grads & sum)
+
+    should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
     """
     vars_status = {}
     dp_grads_status = {}
     idx_last_grad_allreduce = -1
     idx_amp_allreduce = -1
     idx_gradient_clip_allreduce = -1
+
     for idx, op in enumerate(block.ops):
-        if op.type == "c_allreduce_sum":
-            ring_id = op.desc.attr("ring_id")
-            var_name = op.desc.input_arg_names()[0]
-            param = var_name.split("@")[0]
+        # sharding use both allreduce and reduce to sync grad
+        if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
+            if op.all_attrs()["use_calc_stream"] == False:
+                ring_id = op.desc.attr("ring_id")
+                var_name = op.desc.input_arg_names()[0]
+                param = var_name.split("@")[0]
 
-            assert 'sum' in var_name or ("@GRAD" in var_name)
-            if 'sum' in var_name or (not shard.has_param(param)):
-                vars_status[var_name] = -1
-            else:
-                dp_grads_status[var_name] = -1
+                assert 'sum' in var_name or ("@GRAD" in var_name)
+                if 'sum' in var_name or (not shard.has_param(param)):
+                    vars_status[var_name] = -1
+                else:
+                    dp_grads_status[var_name] = -1
 
-            if ring_id != 0:
-                assert shard.has_param(param)
-                assert ring_id == dp_ring_id
+                if ring_id != sharding_ring_id:
+                    assert shard.has_param(param)
+                    assert ring_id == dp_ring_id
 
-            if "sum" in var_name:
-                idx_amp_allreduce = idx
-            elif "@GRAD":
-                idx_last_grad_allreduce = idx
+                if "sum" in var_name:
+                    idx_amp_allreduce = idx
+                elif "@GRAD":
+                    idx_last_grad_allreduce = idx
 
         if op.type == "c_allreduce_max":
             idx_gradient_clip_allreduce = idx
@@ -128,38 +137,41 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                 if var_name in dp_grads_status and dp_grads_status[
                         var_name] == 0:
                     dp_grads_status[var_name] = 1
 
-        elif op.type == "c_allreduce_sum":
-            var_name = op.desc.input_arg_names()[0]
-            ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
-                if var_name in vars_status:
-                    _status = vars_status[var_name]
-                else:
-                    _status = dp_grads_status[var_name]
-                if _status == -1:
-                    raise ValueError("{} is not generated, but you are"
-                                     "trying to all-reduce it".format(var_name))
-                if _status == 0:
-                    raise ValueError("There should be a sync_calc op "
-                                     "after generate Var: {} and before the"
-                                     "c_allreduce_sum op".format(var_name))
-                assert (_status == 1)
-                if var_name in vars_status:
-                    vars_status[var_name] = 2
+        # check sharding allreduce and reduce but skip megatron allreduce
+        elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                ring_id = op.desc.attr("ring_id")
+                if ring_id == sharding_ring_id:
+                    assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
+                    if var_name in vars_status:
+                        _status = vars_status[var_name]
+                    else:
+                        _status = dp_grads_status[var_name]
+                    if _status == -1:
+                        raise ValueError("{} is not generated, but you are"
+                                         "trying to all-reduce it".format(
+                                             var_name))
+                    if _status == 0:
+                        raise ValueError("There should be a sync_calc op "
+                                         "after generate Var: {} and before the"
+                                         "c_allreduce_sum op".format(var_name))
+                    assert (_status == 1)
+                    if var_name in vars_status:
+                        vars_status[var_name] = 2
+                    else:
+                        dp_grads_status[var_name] = 2
                 else:
-                    dp_grads_status[var_name] = 2
-            else:
-                assert ring_id == dp_ring_id
-                param = var_name.split("@")[0]
-                assert shard.has_param(param)
-                assert dp_grads_status[var_name] == 3
-                dp_grads_status[var_name] = 4
+                    assert ring_id == dp_ring_id
+                    param = var_name.split("@")[0]
+                    assert shard.has_param(param)
+                    assert dp_grads_status[var_name] == 3
+                    dp_grads_status[var_name] = 4
 
         elif op.type == "c_sync_comm_stream":
             var_name = op.desc.input_arg_names()[0]
             ring_id = op.desc.attr("ring_id")
-            if ring_id == 0:
+            if ring_id == sharding_ring_id:
                 for var_name in op.desc.input_arg_names():
                     if var_name in vars_status:
                         assert vars_status[var_name] == 2
@@ -181,6 +193,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                     raise ValueError("There should be a sync_comm op "
                                      "after allreduce the Var: {}".format(
                                          input_name))
+                raise ValueError(
+                    "The reduce output grad [{}] should NOT be be used in Non-root rank.".
+                    format(input_name))
             if input_name in dp_grads_status:
                 if dp_ring_id == -1:
                     if dp_grads_status[input_name] != 3:
@@ -325,6 +340,27 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
     return
def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
"""
_add_allreduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
insert_idx,
type='c_reduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
OP_ROLE_KEY: OpRole.Backward
})
return
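The switch from allreduce to c_reduce_sum means each gradient is summed only onto the rank that owns the corresponding parameter shard, the root_id obtained from get_grad_device defined later in this file. A hedged sketch of that routing decision in plain Python; the parameter names and placement map are illustrative, not taken from the commit.

# Sketch: route each gradient to the rank that owns its parameter shard,
# mirroring insert_reduce_ops, where root_id = get_grad_device(var, shard).
param2device = {"fc_0.w_0": 0, "fc_1.w_0": 1}   # illustrative shard placement

def grad_root_rank(grad_name):
    base = grad_name.replace(".cast_fp16@GRAD", "").replace("@GRAD", "")
    return param2device[base]

assert grad_root_rank("fc_1.w_0@GRAD") == 1
assert grad_root_rank("fc_0.w_0.cast_fp16@GRAD") == 0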
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
    """
    _add_broadcast_ops
@@ -428,7 +464,7 @@ def comm_analyse(main_program):
            count))
-def add_sync_comm(program, dist_strategy):
+def add_sync_comm(program, sharding_ring_id):
     """
     When clone a test prog by clone from the sharding main prog,
     part of the sync_comm op maybe be pruned by mistake, this function
@@ -438,6 +474,7 @@ def add_sync_comm(program, dist_strategy):
 
     #NOTE (liangjianzhong): only support one comm stream by now, use more than one
     # comm streams will cause error. should be revise in future.
+    assert sharding_ring_id >= 0, "sharding_ring_id should larger than zero"
     block = program.global_block()
     not_sync_vars = set([])
     for op in block.ops:
@@ -448,15 +485,14 @@ def add_sync_comm(program, dist_strategy):
         for input_name in op.desc.input_arg_names():
             not_sync_vars.remove(input_name)
     if not_sync_vars:
-        for nccl_id in range(dist_strategy.nccl_comm_num):
-            block.append_op(
-                type='c_sync_comm_stream',
-                inputs={'X': list(not_sync_vars)},
-                outputs={'Out': list(not_sync_vars)},
-                attrs={
-                    'ring_id': nccl_id,
-                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
-                })
+        block.append_op(
+            type='c_sync_comm_stream',
+            inputs={'X': list(not_sync_vars)},
+            outputs={'Out': list(not_sync_vars)},
+            attrs={
+                'ring_id': sharding_ring_id,
+                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+            })
     return
@@ -468,7 +504,7 @@ def save_persistables(exe, dirname, main_program, filename=None):
     """
 
     def is_opt_vars(var):
-        # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer
+        # NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
         # now only Momentum and adam are compatible with sharding
         checks = [
             "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0",
@@ -479,12 +515,18 @@ def save_persistables(exe, dirname, main_program, filename=None):
                 return True
         return False
 
+    def is_gradient_merge_vars(var):
+        # NOTE(JZ-LIANG): to revise save/load logic in framework instead of write this naive rule
+        return var.name.endswith("@GradiantMerge")
+
     def is_trainable(var):
         return isinstance(var,
                           paddle.fluid.framework.Parameter) and var.trainable
 
     def sharding_predicate(var):
-        return is_trainable(var) or is_opt_vars(var)
+        return is_trainable(var) or is_opt_vars(var) or is_gradient_merge_vars(
+            var)
 
     if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0:
         paddle.fluid.io.save_persistables(
@@ -498,3 +540,42 @@ def save_persistables(exe, dirname, main_program, filename=None):
             filename=None)
 
     return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
def append_naive_sync(block, sync_var, ring_id):
# NOTE (JZ-LIANG) update this to use barrier sync for more elegent logic
# sync within global
block.append_op(
type="fill_constant",
outputs={"Out": sync_var},
attrs={
"shape": sync_var.shape,
"dtype": sync_var.dtype,
"value": int(1),
})
block.append_op(
type='c_allreduce_sum',
inputs={'X': sync_var},
outputs={'Out': sync_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle.fluid import unique_name, core
 import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
 from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
@@ -24,7 +24,14 @@ from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper impor
 from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
 from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
 from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
+from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
+from paddle.fluid import layers
+
 import logging
+logging.basicConfig(
+    format='%(asctime)s %(levelname)-8s %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S')
 from functools import reduce
 
 __all__ = ["ShardingOptimizer"]
@@ -39,6 +46,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
+            "ModelParallelOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
         self._main_program = None
@@ -50,6 +58,10 @@ class ShardingOptimizer(MetaOptimizerBase):
         # reduced grads to param name
         self._reduced_grads_to_param = {}
         self._shard = Shard()
+        self._verbose = False
+
+        # use sharding as outer parallelism (e.g. inner:Megatron & outer sharding)
+        self.mp_degree = 1
 
     def _can_apply(self):
         if not self.role_maker._is_collective:
@@ -64,7 +76,7 @@ class ShardingOptimizer(MetaOptimizerBase):
 
     def _enable_strategy(self, dist_strategy, context):
         dist_strategy.sharding = True
-        dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32}
+        dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}
     def minimize_impl(self,
                       loss,
@@ -75,11 +87,53 @@ class ShardingOptimizer(MetaOptimizerBase):
         # self._nrings = self.user_defined_strategy.nccl_comm_num
         self._nrings_sharding = 1
         self._nrings_dp = 1
-        self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
-            "fuse_broadcast_MB"]
+
+        # parallelism
+        self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
+            "sharding_degree"])
+        assert self.sharding_degree > 1, "sharding degree must be larger than zero"
+        self.mp_degree = int(self.user_defined_strategy.sharding_configs[
+            "mp_degree"])
         self.hybrid_dp = self.user_defined_strategy.sharding_configs[
             "hybrid_dp"]
self.pp_degree = 1
# dp here is the pure dp as the outest parallelism
self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
self.sharding_degree)
assert self.role_maker._worker_num(
) == self.dp_degree * self.mp_degree * self.sharding_degree * self.pp_degree
if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
# segment
self._sharding_segment_strategy = str(
self.user_defined_strategy.sharding_configs[
"sharding_segment_strategy"])
if self._sharding_segment_strategy == "segment_broadcast_MB":
self._broadcast_MB = self.user_defined_strategy.sharding_configs[
"segment_broadcast_MB"]
assert self._broadcast_MB > 0, "segment size should larger than zero !"
elif self._sharding_segment_strategy == "segment_anchors":
self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[
"segment_anchors"]
assert len(self._sharding_segment_anchors
) > 0, "you should set the sharding segment anchors !"
self._backward_remain_anchors = self._sharding_segment_anchors[:]
self._forward_remain_anchors = []
else:
raise NotImplementedError(
"the sharding segment strategy [{}] is not implemented".format(
str(self._sharding_segment_strategy)))
# gradient merge
self._gradient_merge_acc_step = int(
self.user_defined_strategy.sharding_configs[
"gradient_merge_acc_step"])
self._grad2merged_grad = dict()
         if self.inner_opt is None:
             raise ValueError(
                 "self.inner_opt of ShardingOptimizer should not be None.")
@@ -93,8 +147,11 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._main_program = main_block.program
         self._startup_program = startup_program
 
-        # step1: set_up
-        self._set_up(params_grads)
+        # step0: _init_comm
+        self._init_comm()
+
+        # step1: _build_shard
+        self._build_shard(params_grads)
 
         # step2: split_program
         self._split_program(main_block)
@@ -104,75 +161,166 @@ class ShardingOptimizer(MetaOptimizerBase):
         main_block._sync_with_cpp()
         startup_block._sync_with_cpp()
 
-        # step4: insert reduce_sum for grad
-        insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker._worker_num())
+        # step4: scale the loss by the num of dp degree
+        # sharding is also a senario of dp
+        scale_ = self.dp_degree * self.sharding_degree
+        if scale_ > 1:
+            insert_scale_loss_grad_ops(main_block, scale=1.0 / scale_)
 
         main_block._sync_with_cpp()
 
         # step5: remove unneeded ops and vars from block
         self._prune_main_program(main_block)
         self._prune_startup_program(startup_block)
         if self.hybrid_dp:
             self._initialization_broadcast(startup_program)
 
-        # check op dependecy
-        check_broadcast(main_block)
-        check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
+        # step6: optional gradient merge
+        if self._gradient_merge_acc_step > 1:
+            self._sharding_gradient_merge(main_block)
+
+        # # check op dependecy
+        # FIXME (JZ-LIANG) enable checking in future.
+        # check_broadcast(main_block)
+        # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
+        #                     self.dp_ring_id)
         self._wait()
         return optimize_ops, params_grads
def _set_up(self, params_grads): def _init_comm(self):
# step 1: initialize nccl
self.global_word_size = self.role_maker._worker_num()
self.global_rank = self.role_maker._worker_index()
self.endpoints = self.role_maker._get_trainer_endpoints()
self.current_endpoint = self.endpoints[self.global_rank]
self._collective_helper = CollectiveHelper(self.role_maker,
self._nrings_sharding)
# config sharding & dp groups # config sharding & dp groups
self._init_comm() self._build_group()
# sharding
startup_block = self._startup_program.global_block()
self.startup_prog_sync_var = startup_block.create_var(
name="startup_prog_sync_var",
shape=[1],
dtype=core.VarDesc.VarType.INT32,
persistable=False)
# global
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint, self._startup_program,
self.sharding_group_endpoints, self.sharding_rank, self.current_endpoint,
self.sharding_ring_id, True) self.global_endpoints,
self.global_rank,
self.global_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# mp
if self.mp_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.mp_group_endpoints,
self.mp_rank,
self.mp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding
if self.sharding_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.sharding_group_endpoints,
self.sharding_rank,
self.sharding_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# dp # dp
if self.hybrid_dp: if self.dp_degree > 1:
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint, self._startup_program,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) self.current_endpoint,
self.dp_group_endpoints,
self.dp_rank,
self.dp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
def _build_shard(self, params_grads):
# step 2: split params # step 2: split params
self._params = set([x[0].name for x in params_grads]) self._params = set([x[0].name for x in params_grads])
self._shard.setup(params_grads, self.sharding_rank, self._shard.setup(params_grads, self.sharding_rank,
self.sharding_group_size) self.sharding_degree)
# step 3: get broadcast vars # step 3: get broadcast vars
self._broadcast_vars = self._shard.find_broadcast_params( self._broadcast_vars = self._shard.find_broadcast_params(
self._main_program.global_block()) self._main_program.global_block())
def _wait(self, ): def _wait(self, ):
endpoints = self.role_maker._get_trainer_endpoints() endpoints = self.global_endpoints[:]
current_endpoint = endpoints[self.role_maker._worker_index()] current_endpoint = endpoints[self.global_rank]
if self.role_maker._worker_index() == 0: if self.global_rank == 0:
self._collective_helper._wait(current_endpoint, endpoints) self._collective_helper._wait(current_endpoint, endpoints)
def collect_segment(self, segment, op_idx, block):
segment._start_idx = op_idx + 1
self._segments.insert(0, segment)
new_segment = ProgramSegment(block)
new_segment._end_idx = op_idx + 1
return new_segment
def _split_program(self, block): def _split_program(self, block):
for op_idx, op in reversed(list(enumerate(block.ops))): for op_idx, op in reversed(list(enumerate(block.ops))):
if int(op.attr('op_role')) != int(OpRole.Optimize): if int(op.attr('op_role')) != int(OpRole.Optimize):
last_backward_op_idx = op_idx + 1 last_backward_op_idx = op_idx + 1
break break
var2broadcast_time = dict()
segment = ProgramSegment(block) segment = ProgramSegment(block)
segment._end_idx = last_backward_op_idx segment._end_idx = last_backward_op_idx
for op_idx in reversed(range(last_backward_op_idx)): for op_idx in reversed(range(last_backward_op_idx)):
op = block.ops[op_idx] op = block.ops[op_idx]
assert (int(op.attr('op_role')) != int(OpRole.Optimize)) assert (int(op.attr('op_role')) != int(OpRole.Optimize))
-            if segment._param_mem >= self._fuse_broadcast_MB:
-                segment._start_idx = op_idx + 1
-                self._segments.insert(0, segment)
-                segment = ProgramSegment(block)
-                segment._end_idx = op_idx + 1
+            if self._sharding_segment_strategy == "segment_broadcast_MB":
+                if segment._param_mem >= self._broadcast_MB:
+                    segment = self.collect_segment(segment, op_idx, block)
+
+            elif self._sharding_segment_strategy == "segment_anchors":
if int(op.attr('op_role')) == int(OpRole.Backward):
for input_name in op.desc.input_arg_names():
# NOTE (JZ-LIANG) naive rule to support amp, if amp change, should modify here accordingly
if self.user_defined_strategy.amp:
if ".cast_fp16@GRAD" not in input_name:
continue
else:
input_name = input_name[:input_name.find(
".cast_fp16@GRAD")]
if input_name in self._backward_remain_anchors:
segment = self.collect_segment(segment, op_idx,
block)
assert input_name not in self._forward_remain_anchors, "segment anchor [{}] met twice !".format(
input_name)
self._backward_remain_anchors.remove(input_name)
self._forward_remain_anchors.append(input_name)
elif int(op.attr('op_role')) == int(OpRole.Forward):
for output_name in op.desc.output_arg_names():
if output_name in self._forward_remain_anchors:
segment = self.collect_segment(segment, op_idx,
block)
self._forward_remain_anchors.remove(output_name)
# find broadcast vars # find broadcast vars
for input_name in op.desc.input_arg_names(): for input_name in op.desc.input_arg_names():
...@@ -190,6 +338,21 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -190,6 +338,21 @@ class ShardingOptimizer(MetaOptimizerBase):
broadcast_var_name = unique_name.generate(input_name + broadcast_var_name = unique_name.generate(input_name +
"@BroadCast") "@BroadCast")
segment._fill_constant_vars.append(broadcast_var_name) segment._fill_constant_vars.append(broadcast_var_name)
# (JZ-LIANG) should use Param base name ?
broadcast_var_base_name = input_name
if "subprog" in broadcast_var_base_name:
# remove suffix
broadcast_var_base_name = broadcast_var_base_name[:
broadcast_var_base_name.
find(
".subprog"
)]
var2broadcast_time[
broadcast_var_base_name] = var2broadcast_time.get(
broadcast_var_base_name, 0) + 1
segment._param2broadcast[input_name] = broadcast_var_name segment._param2broadcast[input_name] = broadcast_var_name
segment._broadcast_vars.append((broadcast_var_name, segment._broadcast_vars.append((broadcast_var_name,
self._shard.device(input_name))) self._shard.device(input_name)))
...@@ -219,6 +382,30 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -219,6 +382,30 @@ class ShardingOptimizer(MetaOptimizerBase):
if segment._param_mem > 0: if segment._param_mem > 0:
segment._start_idx = 0 segment._start_idx = 0
self._segments.insert(0, segment) self._segments.insert(0, segment)
if self._sharding_segment_strategy == "segment_anchors":
assert len(
self._forward_remain_anchors) == 0, "remain anchors {}".format(
self._forward_remain_anchors)
assert len(
self._backward_remain_anchors) == 0, "remain anchors {}".format(
self._backward_remain_anchors)
if self._verbose:
for varname in sorted(
var2broadcast_time, key=var2broadcast_time.get,
reverse=True):
logging.info("Sharding broadcast: [{}] times [{}]".format(
var2broadcast_time[varname], varname))
for idx_ in range(len(self._segments)):
logging.info("segment [{}] :".format(idx_))
logging.info("start op: [{}] [{}]".format(block.ops[
self._segments[idx_]._start_idx].desc.type(), block.ops[
self._segments[idx_]._start_idx].desc.input_arg_names(
)))
logging.info("end op: [{}] [{}]".format(block.ops[
self._segments[idx_]._end_idx].desc.type(), block.ops[
self._segments[idx_]._end_idx].desc.input_arg_names()))
return return
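As a side note, the segment_broadcast_MB strategy handled above simply closes the current segment whenever its accumulated parameter memory crosses the configured threshold. A hedged stand-alone sketch of that splitting rule in plain Python; the function name and sizes are illustrative, not the optimizer's internals.

# Sketch: group parameters into segments by an accumulated-memory threshold,
# mirroring the segment_broadcast_MB rule used in _split_program.
def split_by_memory(param_sizes_mb, threshold_mb=32.0):
    segments, current, acc = [], [], 0.0
    for size in param_sizes_mb:
        current.append(size)
        acc += size
        if acc >= threshold_mb:       # close this segment, start a new one
            segments.append(current)
            current, acc = [], 0.0
    if current:
        segments.append(current)
    return segments

print(split_by_memory([10, 10, 15, 40, 5, 30], threshold_mb=32.0))
# [[10, 10, 15], [40], [5, 30]]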
     def _prune_main_program(self, block):
@@ -234,10 +421,21 @@ class ShardingOptimizer(MetaOptimizerBase):
         """
         weightdecay_helper = WeightDecayHelper()
         weightdecay_helper.prune_weight_decay(block, self._shard)
+        # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
+        # group. and each Data Parallelism group should have its own sync of FoundInfinite
+        # amp could use global group for sync
         FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
-                             self.sharding_ring_id)
-        gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
-        gradientclip_helper.prune_gradient_clip(block, self._shard)
+                             self.global_ring_id)
+
+        # clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
+        if self.mp_degree * self.pp_degree == 1:
+            # separate the sharding-hybrid senario to keep the accuracy
+            gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
+            gradientclip_helper.prune_gradient_clip(
+                block, self._shard, pure_dp_degree=1)
+        else:
+            gradientclip_helper = GradientClipHelper(self.global_ring_id)
+            gradientclip_helper.prune_gradient_clip(
+                block, self._shard, pure_dp_degree=self.dp_degree)
 
         # build prog deps
         reduced_grads = []
@@ -307,7 +505,8 @@ class ShardingOptimizer(MetaOptimizerBase):
     def _add_broadcast_allreduce(self, block):
         """
-        _add_broadcast_allreduce
+        add broadcast allreduce op
+        if enable gradient_merge, insert related ops
         """
         if len(self._segments) < 1:
             return
...@@ -315,17 +514,27 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -315,17 +514,27 @@ class ShardingOptimizer(MetaOptimizerBase):
if self._segments[-1]._allreduce_vars: if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1] shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars) ._allreduce_vars)
if self.hybrid_dp and len(shard_allredue_vars) >= 1: if self._gradient_merge_acc_step <= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx, if self.hybrid_dp and len(shard_allredue_vars) >= 1:
self.dp_ring_id, shard_allredue_vars) insert_sync_comm_ops(block, self._segments[-1]._end_idx,
insert_allreduce_ops(block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars)
self.dp_ring_id, shard_allredue_vars) insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
# gradient merge
else:
self.create_persistable_gradients_and_insert_merge_ops(
block,
self._startup_program.global_block(),
self._segments[-1]._end_idx, shard_allredue_vars,
self._shard)
insert_sync_comm_ops(block, self._segments[-1]._end_idx, insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id, self.sharding_ring_id,
self._segments[-1]._allreduce_vars) self._segments[-1]._allreduce_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx, # allreduce --> reduce
self.sharding_ring_id, insert_reduce_ops(block, self._segments[-1]._end_idx,
self._segments[-1]._allreduce_vars) self.sharding_ring_id,
self._segments[-1]._allreduce_vars, self._shard)
for idx, segment in reversed(list(enumerate(self._segments))): for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[ allreduce_vars = self._segments[
...@@ -364,19 +573,31 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -364,19 +573,31 @@ class ShardingOptimizer(MetaOptimizerBase):
# step2: add Sync ops # step2: add Sync ops
shard_allredue_vars = self._shard.filter_grads(allreduce_vars) shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id,
shard_allredue_vars)
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx,
self.dp_ring_id, shard_allredue_vars)
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id,
broad_cast_vars)
else:
comm_dep_vars = allreduce_vars + [
x[0] for x in broadcast_vars
]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id,
comm_dep_vars)
# gradient merge
else:
broad_cast_vars = [x[0] for x in broadcast_vars] broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0: if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx, insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, broad_cast_vars) self.sharding_ring_id, broad_cast_vars)
else:
comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, comm_dep_vars)
calc_dep_vars = fill_constant_vars + [ calc_dep_vars = fill_constant_vars + [
k for k, v in cast_ops.items() k for k, v in cast_ops.items()
...@@ -394,18 +615,32 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -394,18 +615,32 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_cast_ops(block, segment._end_idx, cast_ops) insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops # step5: add broadcast ops
# gradient merge
if self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
block,
self._startup_program.global_block(), segment._start_idx,
shard_allredue_vars, self._shard)
insert_broadcast_ops(block, segment._start_idx, insert_broadcast_ops(block, segment._start_idx,
self.sharding_ring_id, broadcast_vars) self.sharding_ring_id, broadcast_vars)
# step6: add all_reduce ops # step6: add all_reduce ops
# dp # dp
if self.hybrid_dp and len(shard_allredue_vars) >= 1: if self._gradient_merge_acc_step <= 1:
insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id, if self.hybrid_dp and len(shard_allredue_vars) >= 1:
shard_allredue_vars) insert_allreduce_ops(block, segment._start_idx,
self.dp_ring_id, shard_allredue_vars)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# gradient merge
else:
insert_sync_comm_ops(block, segment._start_idx, insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars) self.sharding_ring_id, allreduce_vars)
# sharding # sharding
insert_allreduce_ops(block, segment._start_idx, # allreduce --> reduce
self.sharding_ring_id, allreduce_vars) insert_reduce_ops(block, segment._start_idx, self.sharding_ring_id,
allreduce_vars, self._shard)
block._sync_with_cpp() block._sync_with_cpp()
...@@ -456,59 +691,440 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -456,59 +691,440 @@ class ShardingOptimizer(MetaOptimizerBase):
block._remove_var(var_name, sync=False) block._remove_var(var_name, sync=False)
block._sync_with_cpp() block._sync_with_cpp()
def _init_comm(self): def _build_group(self):
"""
if self.hybrid_dp: pre-assign ring ids
self.sharding_group_size = self.user_defined_strategy.sharding_configs[ mp: 0
"sharding_group_size"] sharding: 1
self.sharding_ring_id = 0 pure-dp: 2
self.sharding_rank = self.global_rank % self.sharding_group_size global: 3
pp: >= 20
self.dp_group_size = self.global_word_size // self.sharding_group_size if one parallelism is not enable: -1
self.dp_rank = self.global_rank // self.sharding_group_size and only support parallelism hierarchy: mp --> sharding --> pp --> dp
self.dp_ring_id = self.sharding_rank + 1 """
# step 1: initialize nccl
self.sharding_group_endpoints = [ self.global_word_size = self.role_maker._worker_num()
ep for idx, ep in enumerate(self.endpoints) self.global_rank = self.role_maker._worker_index()
if (idx // self.sharding_group_size) == self.dp_rank self.global_endpoints = self.role_maker._get_trainer_endpoints()
] self.current_endpoint = self.global_endpoints[self.global_rank]
self.dp_group_endpoints = [ self._collective_helper = CollectiveHelper(
ep for idx, ep in enumerate(self.endpoints) self.role_maker, nrings=self._nrings_sharding)
if (idx % self.sharding_group_size) == self.sharding_rank assert self.global_word_size % self.mp_degree == 0, \
"global_word_size: {} should be divisible to the mp_degree: {}".format(self.global_word_size, self.mp_degree)
assert self.global_word_size % self.sharding_degree == 0, \
"global_word_size: {} should be divisible to the sharding_degree: {}".format(self.global_word_size, self.sharding_degree)
assert self.global_word_size % self.pp_degree == 0, \
"global_word_size: {} should be divisible to the pp_degree: {}".format(self.global_word_size, self.pp_degree)
assert self.global_word_size % self.dp_degree == 0, \
"global_word_size: {} should be divisible to the dp_degree: {}".format(self.global_word_size, self.dp_degree)
# mp group
if self.mp_degree > 1:
self.mp_ring_id = 0
self.mp_rank = self.global_rank % self.mp_degree
self.mp_group_id = self.global_rank // self.mp_degree
self.mp_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints)
if idx // self.mp_degree == self.mp_group_id
] ]
assert self.global_word_size > self.sharding_group_size, \ assert self.current_endpoint in self.mp_group_endpoints
"global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) assert len(
assert self.global_word_size % self.sharding_group_size == 0, \ self.mp_group_endpoints
"global_word_size: {} should be divisible to the sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) ) == self.mp_degree, "num of mp worker in group is [{}], but mp group size is [{}]".format(
assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \ len(self.mp_group_endpoints), self.mp_degree)
"global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format( else:
self.global_word_size, self.mp_degree = 1
self.sharding_group_size, self.mp_ring_id = -1
self.dp_group_size) self.mp_rank = -1
self.mp_group_id = -1
logging.info("Using Sharing&DP mode !") self.mp_group_endpoints = []
# sharding
if self.sharding_degree > 1:
self.sharding_ring_id = 1
self.sharding_rank = (self.global_rank //
self.mp_degree) % self.sharding_degree
self.sharding_group_id = self.global_rank // (self.mp_degree *
self.sharding_degree)
# mp + sharding + ...
if self.mp_degree > 1:
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints)
if (idx // (self.mp_degree * self.sharding_degree)) == self.
sharding_group_id and idx % self.mp_degree == self.mp_rank
]
# sharding + ...
else:
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.global_endpoints)
if (idx // (self.mp_degree * self.sharding_degree)
) == self.sharding_group_id
]
assert self.current_endpoint in self.sharding_group_endpoints
else:
self.sharding_degree = 1
self.sharding_ring_id = -1
self.sharding_rank = -1
self.sharding_group_id = -1
self.sharding_group_endpoints = []
# outter-pure-dp group
# NOTE (JZ-LIANG) support outter-pure-dp to scale the throughput in 3D parallelism
# e.g. mp-sharding-pp-dp
# sharding-hybrid-dp as one senario of outter-pure-dp
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree, self.sharding_degree, self.pp_degree,
self.dp_degree, self.global_word_size)
if self.dp_degree > 1:
self.dp_ring_id = 2
self.dp_rank = self.global_rank // (self.sharding_degree *
self.mp_degree * self.pp_degree)
dp_first_rank_idx = self.global_rank % (
self.sharding_degree * self.mp_degree * self.pp_degree)
dp_offset = (self.sharding_degree * self.mp_degree * self.pp_degree)
self.dp_group_endpoints = []
for i in range(self.dp_degree):
self.dp_group_endpoints.append(self.global_endpoints[
dp_first_rank_idx + dp_offset * i])
assert self.current_endpoint in self.dp_group_endpoints
logging.info("Hybrid DP mode turn on !")
else: else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
self.dp_ring_id = -1 self.dp_ring_id = -1
self.dp_rank = -1 self.dp_rank = -1
self.dp_group_size = None self.dp_group_endpoints = []
self.dp_group_endpoints = None
logging.info("Using Sharing alone mode !") # global group
self.global_ring_id = 3
logging.info("global word size: {}".format(self.global_word_size)) logging.info("global word size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank)) logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size)) logging.info("global endpoints: {}".format(self.global_endpoints))
logging.info("global ring id: {}".format(self.global_ring_id))
logging.info("#####" * 6)
logging.info("mp group size: {}".format(self.mp_degree))
logging.info("mp rank: {}".format(self.mp_rank))
logging.info("mp group id: {}".format(self.mp_group_id))
logging.info("mp group endpoints: {}".format(self.mp_group_endpoints))
logging.info("mp ring id: {}".format(self.mp_ring_id))
logging.info("#####" * 6)
logging.info("sharding group size: {}".format(self.sharding_degree))
logging.info("sharding rank: {}".format(self.sharding_rank)) logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("dp group size: {}".format(self.dp_group_size)) logging.info("sharding group id: {}".format(self.sharding_group_id))
logging.info("dp rank: {}".format(self.dp_rank))
logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("sharding group endpoints: {}".format( logging.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints)) self.sharding_group_endpoints))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) logging.info("sharding ring id: {}".format(self.sharding_ring_id))
logging.info("global word endpoints: {}".format(self.endpoints)) logging.info("#####" * 6)
logging.info("outter pure dp group size: {}".format(self.dp_degree))
logging.info("outter pure dp rank: {}".format(self.dp_rank))
logging.info("outter pure dp group endpoints: {}".format(
self.dp_group_endpoints))
logging.info("outter pure dp ring id: {}".format(self.dp_ring_id))
logging.info("#####" * 6)
return return
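The _build_group docstring above pre-assigns ring ids (mp: 0, sharding: 1, pure-dp: 2, global: 3) and only supports the parallelism hierarchy mp --> sharding --> pp --> dp. A hedged sketch of how a global rank decomposes into per-parallelism ranks under that hierarchy; pp is included for completeness even though pp_degree is fixed to 1 in this commit, and the concrete degrees below are illustrative.

# Sketch: decompose a global rank under the hierarchy mp -> sharding -> pp -> dp.
def decompose_rank(global_rank, mp_degree, sharding_degree, pp_degree, dp_degree):
    mp_rank = global_rank % mp_degree
    sharding_rank = (global_rank // mp_degree) % sharding_degree
    pp_rank = (global_rank // (mp_degree * sharding_degree)) % pp_degree
    dp_rank = global_rank // (mp_degree * sharding_degree * pp_degree)
    return mp_rank, sharding_rank, pp_rank, dp_rank

# 2-way mp, 2-way sharding, no pp, 2-way pure dp: 8 ranks total
assert decompose_rank(5, 2, 2, 1, 2) == (1, 0, 0, 1)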
def _initialization_broadcast(self, startup_prog):
"""
this funtion is to ensure the initialization between dp group to be
identical when hybrid-dp is used.
"""
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
params.append(param)
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
# sync within global group
append_naive_sync(block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding gradient merge
def create_persistable_gradients_and_insert_merge_ops(
self, main_block, startup_block, insert_idx, grad_names, shard):
for grad_name in grad_names:
assert get_grad_device(
grad_name, shard
) == shard.worker_idx, "try to merge gradient not belong to current shard: [{}]".format(
grad_name)
persistable_grad_name = grad_name + '@GradiantMerge'
assert grad_name not in self._grad2merged_grad, "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(
grad_name)
self._grad2merged_grad[grad_name] = persistable_grad_name
grad_var = main_block.var(grad_name)
# create var
gradient_merge_var = main_block.create_var(
name=persistable_grad_name,
shape=grad_var.shape,
dtype=grad_var.dtype,
persistable=True)
startup_gradient_merge_var = startup_block.create_var(
name=persistable_grad_name,
shape=grad_var.shape,
dtype=grad_var.dtype,
persistable=True)
# merge gradient
main_block._insert_op_without_sync(
insert_idx,
type="elementwise_add",
inputs={'X': grad_name,
'Y': gradient_merge_var},
outputs={'Out': gradient_merge_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward
})
# startup initialization
startup_block.append_op(
type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": grad_var.shape,
"dtype": grad_var.dtype,
"value": float(0),
})
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
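For intuition, the persistable @GradiantMerge buffers created above implement plain gradient accumulation: every micro step adds the fresh gradient into the buffer, and only every gradient_merge_acc_step-th step is the averaged buffer applied and reset, which is what the conditional block built by the following helpers does. A hedged sketch of that accumulation loop in plain Python with illustrative names:

# Sketch: gradient merge (accumulation) over k micro-steps.
def gradient_merge_step(step, grad, merged, k):
    merged += grad                  # elementwise_add into the @GradiantMerge buffer
    step = (step + 1) % k           # increment + elementwise_mod, as in _create_gm_cond
    if step == 0:                   # cond_var: apply and reset every k steps
        update = merged / float(k)  # the 'scale' op: grad@gradientmerge / acc_step
        merged = 0.0                # fill_constant back to zero
        return step, merged, update
    return step, merged, None

step, merged = 0, 0.0
for grad in [1.0, 2.0, 3.0, 4.0]:
    step, merged, update = gradient_merge_step(step, grad, merged, k=2)
    if update is not None:
        print("apply averaged grad:", update)  # 1.5, then 3.5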
def _create_gm_cond(self, main_block):
# Add const var
acc_step_var = layers.create_global_var(
name="gradient_merge_acc_step",
shape=[1],
value=int(self._gradient_merge_acc_step),
dtype='int32',
persistable=True,
force_cpu=True)
zero_var = layers.create_global_var(
name="gradient_merge_zero",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
# Add step var & cond var
current_step_var = layers.create_global_var(
name="gradient_merge_current_step",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
cond_var = layers.create_global_var(
name="gradient_merge_cond",
shape=[1],
value=bool(0),
dtype='bool',
persistable=False,
force_cpu=True)
with device_guard("cpu"):
# step_var = (step_var + 1) % k_step
main_block.append_op(
type='increment',
inputs={'X': [current_step_var]},
outputs={'Out': [current_step_var]},
attrs={'step': float(1),
OP_ROLE_KEY: OpRole.Optimize})
main_block.append_op(
type='elementwise_mod',
inputs={'X': current_step_var,
'Y': acc_step_var},
outputs={'Out': current_step_var},
attrs={
'axis': -1,
OP_ROLE_KEY: OpRole.Optimize,
'use_mkldnn': False
})
# cond_var = (step_var == 0)
main_block.append_op(
type='equal',
inputs={'X': current_step_var,
'Y': zero_var},
outputs={'Out': cond_var},
attrs={OP_ROLE_KEY: OpRole.Optimize})
# paddle.static.Print(current_step_var, message="in FWBW last conditional")
return cond_var
def _true_apply_gradient(self):
"""
allreduce grad@gradientmerge in dp group
grad@gradientmerge / acc_step
re-create all optimize ops of origin main block and rename them
cast(backward)
amp
clip
opt
# fill constant grad@gradientmerge
"""
# current conditional block
main_block = self._main_program.global_block()
cur_block_idx = self._main_program.current_block_idx
cur_block = self._main_program.current_block()
self.cond_block = self._main_program.current_block()
# cur_block's forward_block & backward_block is itself
cur_block._set_forward_block_idx(cur_block_idx)
# allreduce grad@gradientmerge
if self.hybrid_dp:
assert self.dp_ring_id >= 0, "dp_ring_id should be non-negative when in sharding & DP mode"
for grad, merged_grad in self._grad2merged_grad.items():
merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(
type='c_allreduce_sum',
inputs={'X': merged_grad_var},
outputs={'Out': merged_grad_var},
attrs={
'ring_id': self.dp_ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
# grad@gradientmerge / acc_step
for grad, merged_grad in self._grad2merged_grad.items():
# grad /= k_steps
merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(
type='scale',
inputs={'X': merged_grad_var},
outputs={'Out': merged_grad_var},
attrs={
'scale': 1.0 / float(self._gradient_merge_acc_step),
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
# re-create optimize ops
already_moved_var_names = []
for op_desc in self.original_optimize_ops_desc:
new_op_desc = cur_block.desc.append_op()
new_op_desc.copy_from(op_desc)
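# rewire the copied optimizer-related op to consume and produce the merged
# gradient buffers instead of the per-step gradients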
for input_name in new_op_desc.input_arg_names():
if input_name in self._grad2merged_grad:
new_op_desc._rename_input(
input_name, self._grad2merged_grad[input_name])
for output_name in new_op_desc.output_arg_names():
if output_name in self._grad2merged_grad:
new_op_desc._rename_output(
output_name, self._grad2merged_grad[output_name])
# move non temp optimize vars from block0 to cond block
if output_name not in already_moved_var_names and output_name not in self._grad2merged_grad.keys(
):
var_ = self._main_program.global_block().var(output_name)
if not var_.persistable:
# move
name_ = var_.name
shape_ = var_.shape
type_ = var_.dtype
self._main_program.global_block()._remove_var(
var_.name, sync=False)
self.cond_block.create_var(
name=name_,
shape=shape_,
dtype=type_,
persistable=False)
already_moved_var_names.append(name_)
self._main_program.global_block()._sync_with_cpp()
cur_block._sync_with_cpp()
# fill zero to grad@gradientmerge
for grad, merged_grad in self._grad2merged_grad.items():
merged_grad_var = main_block.var(merged_grad)
cur_block.append_op(
type='fill_constant',
outputs={'Out': merged_grad_var},
attrs={
"shape": merged_grad_var.shape,
"dtype": merged_grad_var.dtype,
"value": float(0),
OP_ROLE_KEY: OpRole.Optimize
})
# lr_var = main_block.var("gradient_merge_current_step")
# paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
def _sharding_gradient_merge(self, main_block):
"""
copy all optimize ops in origin main block
remove all optimize ops in origin main block
create cond block
"""
# copy original optimize ops to temp ops desc list
# remove them from block 0
tmp_copy_block = self._main_program._create_block()
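# the temp block only serves as a holder for the copied op descs; it is not
# wired into execution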
self.original_optimize_ops_desc = []
for op_idx, op in reversed(list(enumerate(main_block.ops))):
if int(op.attr('op_role')) != int(OpRole.Optimize):
continue
else:
tmp_op_desc = tmp_copy_block.desc.append_op()
tmp_op_desc.copy_from(op.desc)
self.original_optimize_ops_desc.append(tmp_op_desc)
main_block._remove_op(op_idx, sync=False)
tmp_copy_block._sync_with_cpp()
self.original_optimize_ops_desc = list(
reversed(self.original_optimize_ops_desc))
# back to block 0
self._main_program._rollback()
# create cond vars and ops at the end of block 0
cond = self._create_gm_cond(main_block)
# create cond block
cond_block = self._main_program._create_block()
self._true_apply_gradient()
# back to block 0
self._main_program._rollback()
# cond op
step_scope = self._main_program.global_block().create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
conditional_block_op = self._main_program.global_block().append_op(
type='conditional_block',
inputs={
'Cond': cond,
'Input': [],
},
outputs={'Out': [],
'Scope': [step_scope]},
attrs={
'sub_block': cond_block,
'is_scalar_condition': True,
})
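As a reading aid (not part of the diff): a minimal sketch of how these new sharding options would be enabled from user code, assuming the standard fleet collective workflow and using the configuration keys exercised by the tests below; the toy network and the hyper-parameter values are placeholders.
import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
paddle.enable_static()
fleet.init(is_collective=True)
# toy static-graph model standing in for the user's real network
x = fluid.layers.data(name='x', shape=[32], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=x, size=2, act='softmax')
avg_cost = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=y))
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 32.0,   # segment the program by broadcast volume (MB)
    "sharding_degree": 8,           # ranks that shard optimizer states / grads
    "mp_degree": 1,                 # megatron-style model-parallel degree
    "hybrid_dp": True,              # outer data-parallel group on top of sharding
    "gradient_merge_acc_step": 4,   # accumulate 4 micro-steps per optimizer step
}
optimizer = paddle.fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
Such a script is meant to be started with paddle.distributed.launch / fleetrun; the global trainer count is expected to match sharding_degree * mp_degree times the outer dp degree.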
...@@ -115,7 +115,7 @@ class ProgramStats(object): ...@@ -115,7 +115,7 @@ class ProgramStats(object):
updated_min_idx = min_idx updated_min_idx = min_idx
while idx_ > pre_segment_end_idx: while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]): if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[ _logger.info("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0])) 0]))
updated_min_idx = idx_ updated_min_idx = idx_
...@@ -155,7 +155,7 @@ class ProgramStats(object): ...@@ -155,7 +155,7 @@ class ProgramStats(object):
sorted_checkpoints = [] sorted_checkpoints = []
for name in checkpoints_name: for name in checkpoints_name:
if name not in self.var_op_deps: if name not in self.var_op_deps:
_logger.debug( _logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name) % name)
elif self.var_op_deps[name]["var_as_output_ops"] == []: elif self.var_op_deps[name]["var_as_output_ops"] == []:
...@@ -784,7 +784,6 @@ def _append_backward_ops_with_checkpoints_( ...@@ -784,7 +784,6 @@ def _append_backward_ops_with_checkpoints_(
start_idx = 0 start_idx = 0
pre_segment_end_idx = -1 pre_segment_end_idx = -1
while True: while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1: if start_idx >= len(checkpoints_name) - 1:
break break
# min_idx: checkpoint_1' s input op # min_idx: checkpoint_1' s input op
...@@ -797,6 +796,9 @@ def _append_backward_ops_with_checkpoints_( ...@@ -797,6 +796,9 @@ def _append_backward_ops_with_checkpoints_(
min_idx = program_stat._update_segment_start( min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx) min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1]) segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
start_idx += 1 start_idx += 1
...@@ -806,15 +808,15 @@ def _append_backward_ops_with_checkpoints_( ...@@ -806,15 +808,15 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = segments recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments): for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i)) _logger.info("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names())) ), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[ _logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i)) _logger.info("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names())) ), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[ _logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory # 2) go through all forward ops and induct all variables that will be hold in memory
...@@ -825,9 +827,7 @@ def _append_backward_ops_with_checkpoints_( ...@@ -825,9 +827,7 @@ def _append_backward_ops_with_checkpoints_(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name) cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars)) len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory # b. output of seed op should be kept in memory
...@@ -888,6 +888,17 @@ def _append_backward_ops_with_checkpoints_( ...@@ -888,6 +888,17 @@ def _append_backward_ops_with_checkpoints_(
continue continue
if name not in var_name_dict: if name not in var_name_dict:
var_name_dict[name] = name + var_suffix var_name_dict[name] = name + var_suffix
# we should create the rename var in subprog, otherwise its VarType will be BOOL
ref_var = block.program.global_block().var(name)
block.create_var(
name=var_name_dict[name],
shape=ref_var.shape,
dtype=ref_var.dtype,
type=ref_var.type,
persistable=ref_var.persistable,
stop_gradient=ref_var.stop_gradient)
# 3.a. add ops in current recompute_segment as forward recomputation ops # 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory) vars_in_memory)
......
...@@ -59,7 +59,11 @@ def runtime_main(): ...@@ -59,7 +59,11 @@ def runtime_main():
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2} strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"sharding_degree": 2,
}
optimizer = paddle.fluid.optimizer.Momentum( optimizer = paddle.fluid.optimizer.Momentum(
learning_rate=0.01, momentum=0.9) learning_rate=0.01, momentum=0.9)
......
...@@ -146,7 +146,11 @@ class TestFleetMetaOptimizer(unittest.TestCase): ...@@ -146,7 +146,11 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy.gradient_merge_configs = {"k_steps": 2, "avg": True} strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
elif name == "sharding": elif name == "sharding":
strategy.sharding = True strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2} strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"sharding_degree": 2,
}
elif name == "recompute-offload": elif name == "recompute-offload":
strategy.recompute = True strategy.recompute = True
strategy.recompute_configs = { strategy.recompute_configs = {
......
...@@ -1125,6 +1125,7 @@ class TestDistBase(unittest.TestCase): ...@@ -1125,6 +1125,7 @@ class TestDistBase(unittest.TestCase):
if check_error_log: if check_error_log:
print("outs[0]:", outs[0]) print("outs[0]:", outs[0])
print("outs[1]:", outs[1]) print("outs[1]:", outs[1])
return pickle.loads(outs[0]), pickle.loads(outs[1]) return pickle.loads(outs[0]), pickle.loads(outs[1])
def _run_pipeline(self, model, envs, check_error_log, log_name): def _run_pipeline(self, model, envs, check_error_log, log_name):
......
...@@ -45,6 +45,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -45,6 +45,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
])) ]))
self.assertEqual(ops, [ self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
...@@ -55,9 +56,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -55,9 +56,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'momentum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum' 'momentum', 'momentum'
]) ])
def test_sharding_amp_optimizer(self): def test_sharding_amp_optimizer(self):
...@@ -82,6 +83,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -82,6 +83,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0", "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0",
"loss_scaling_0", "num_bad_steps_0", "num_good_steps_0" "loss_scaling_0", "num_bad_steps_0", "num_good_steps_0"
])) ]))
self.assertEqual(ops, [ self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast',
...@@ -94,11 +96,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -94,11 +96,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 'cast', 'c_sync_comm_stream', 'cast', 'cast', 'cast',
'check_finite_and_unscale', 'cast', 'c_sync_calc_stream', 'check_finite_and_unscale', 'cast', 'c_allreduce_max', 'cast',
'c_allreduce_max', 'c_sync_comm_stream', 'cast',
'update_loss_scaling', 'momentum', 'momentum', 'momentum' 'update_loss_scaling', 'momentum', 'momentum', 'momentum'
]) ])
...@@ -124,6 +125,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -124,6 +125,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
])) ]))
self.assertEqual(ops, [ self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
...@@ -134,10 +136,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -134,10 +136,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul', 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad',
'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum', 'mul_grad', 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream', 'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
'momentum', 'momentum', 'momentum'
]) ])
def test_sharding_amp_recompute_optimizer(self): def test_sharding_amp_recompute_optimizer(self):
...@@ -167,29 +168,27 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -167,29 +168,27 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0", "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0",
"loss_scaling_0", "num_bad_steps_0", "num_good_steps_0" "loss_scaling_0", "num_bad_steps_0", "num_good_steps_0"
])) ]))
self.assertEqual(ops, [ self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', 'cast', 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
'cast', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add',
'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'mean', 'elementwise_mul', 'fill_constant', 'scale', 'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast', 'cross_entropy_grad2', 'cast', 'softmax_grad',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'elementwise_add_grad', 'mul_grad', 'cast', 'mul',
'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'cast',
'c_sync_comm_stream', 'cast', 'cast', 'cast', 'cast', 'cast', 'check_finite_and_unscale', 'cast',
'check_finite_and_unscale', 'cast', 'c_sync_calc_stream', 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'c_allreduce_max', 'c_sync_comm_stream', 'cast', 'momentum', 'momentum'
'update_loss_scaling', 'momentum', 'momentum', 'momentum'
]) ])
def test_sharding_weight_decay(self): def test_sharding_weight_decay(self):
...@@ -227,10 +226,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -227,10 +226,10 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'scale',
'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale', 'sum', 'scale', 'sum', 'scale', 'sum', 'momentum', 'momentum',
'sum', 'momentum', 'momentum', 'momentum' 'momentum'
]) ])
def test_sharding_gradient_clip(self): def test_sharding_gradient_clip(self):
...@@ -253,6 +252,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -253,6 +252,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
])) ]))
self.assertEqual(ops, [ self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
...@@ -263,14 +263,12 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -263,14 +263,12 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'square',
'c_sync_comm_stream', 'square', 'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum',
'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream', 'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant', 'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_max', 'elementwise_div', 'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum', 'momentum'
'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
'momentum'
]) ])
def test_sharding_clone_for_test(self): def test_sharding_clone_for_test(self):
...@@ -281,7 +279,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -281,7 +279,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self.optimizer(avg_cost, strategy, train_prog, startup_prog) self.optimizer(avg_cost, strategy, train_prog, startup_prog)
sharding.utils.comm_analyse(train_prog) sharding.utils.comm_analyse(train_prog)
test_prog = train_prog.clone(for_test=True) test_prog = train_prog.clone(for_test=True)
sharding.utils.add_sync_comm(test_prog, strategy) # assume sharding_ring_id = 1
sharding.utils.add_sync_comm(test_prog, 1)
ops = [op.type for op in test_prog.global_block().ops] ops = [op.type for op in test_prog.global_block().ops]
self.assertEqual(ops, [ self.assertEqual(ops, [
...@@ -293,5 +292,200 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): ...@@ -293,5 +292,200 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
]) ])
class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "3"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004"
def test_sharding_with_mp(self):
# NOTE(JZ-LIANG) MP parallelism requires the user to build the model with the MP API
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"hybrid_dp": False,
"gradient_merge_acc_step": 1,
"mp_degree": 2
}
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# should have a ring id for MP
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(0, created_ring_ids)
# check correctness of MP group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
def test_sharding_hybrid_dp(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 1,
"mp_degree": 1
}
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check ring id for outer dp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(2, created_ring_ids)
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
# check loss scale for sharding hybrid dp
scale_ = -1
for op in main_prog_ops:
if op.type == "scale":
scale_ = float(op.desc.attr("scale"))
self.assertEqual(scale_, 0.25)
# check program (allreduce)
ops = [op.type for op in main_prog_ops]
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
])
def test_sharding_hybrid_dp_gm(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, _ = self.net(train_prog, startup_prog)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 4,
"mp_degree": 1
}
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check ring id for outer dp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(2, created_ring_ids)
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of dp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
# check program
fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
opt_ops = [op.type for op in train_prog.blocks[2].ops]
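# block 0 holds forward/backward plus gradient accumulation; block 1 is the
# temp holder block; block 2 is the conditional_block body that runs the
# real optimizer once every gradient_merge_acc_step steps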
self.assertEqual(fw_bw_ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
'elementwise_add', 'increment', 'elementwise_mod', 'equal',
'conditional_block'
])
self.assertEqual(opt_ops, [
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
'scale', 'scale', 'momentum', 'momentum', 'momentum',
'fill_constant', 'fill_constant', 'fill_constant'
])
# check loss scale for gradient merge
scale_ = -1
for op in train_prog.blocks[2].ops:
if op.type == "scale":
scale_ = float(op.desc.attr("scale"))
self.assertEqual(scale_, 0.25)
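# 0.25 == 1.0 / gradient_merge_acc_step (acc_step is 4 in this test)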
if __name__ == "__main__":
    unittest.main()