Commit 920806db authored by sandyhouse

update

Parent 70eb21c5
......@@ -126,6 +126,9 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# remove check_finite_and_unscale op if its input 'X' is empty
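# (illustrative note, not part of the original commit: sharding pruning may have
# removed every gradient in this op's 'X' from the current card, leaving it empty)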
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
......
......@@ -28,21 +28,24 @@ def check_broadcast(block):
if the broadcasted var has a fill_constant op, the fill_constant
op should come before the broadcast op, and before a
sync_calc op. Otherwise, raise an error.
broadcast ops of inner parallelism (e.g. Megatron) should be ignored and skipped
"""
broadcast_vars = {}
for idx, op in enumerate(block.ops):
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[var_name][
"broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[
var_name]["broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
for idx, op in enumerate(block.ops):
if op.type == "fill_constant":
......@@ -61,14 +64,15 @@ def check_broadcast(block):
last_sync_calc_op_idx = idx
continue
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
for input_name in op.desc.input_arg_names():
if input_name in broadcast_vars:
assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
......@@ -78,7 +82,7 @@ def check_broadcast(block):
return
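# Illustrative sketch (not part of the original commit): the ordering invariant that
# check_broadcast() verifies for each "@BroadCast" var reduces to "fill_constant comes
# before a c_sync_calc_stream, which comes before the c_broadcast that consumes the var";
# broadcasts issued with use_calc_stream=True (inner parallelism) are skipped.
# A minimal standalone version of that check over a toy op-type sequence:
def _broadcast_ordering_ok(op_types):
    # op_types: op type strings touching one broadcast var, in program order
    try:
        fill = op_types.index('fill_constant')
        sync = op_types.index('c_sync_calc_stream', fill + 1)
        op_types.index('c_broadcast', sync + 1)
    except ValueError:
        return False
    return True

assert _broadcast_ordering_ok(
    ['fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'elementwise_add'])
assert not _broadcast_ordering_ok(
    ['c_broadcast', 'fill_constant', 'c_sync_calc_stream'])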
def check_allreduce_sum(block, shard, dp_ring_id=-1):
def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
"""
the op order should be:
grad:
......@@ -89,32 +93,36 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
- 4: allreduce_sum_dp (dp_grads)
- 5: sync_comm (dp_grads)
- 6: op that uses Var (dp_grads & sum)
allreduce ops of inner parallelism (e.g. Megatron) should be ignored and skipped
"""
vars_status = {}
dp_grads_status = {}
idx_last_grad_allreduce = -1
idx_amp_allreduce = -1
idx_gradient_clip_allreduce = -1
for idx, op in enumerate(block.ops):
if op.type == "c_allreduce_sum":
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
if op.all_attrs()["use_calc_stream"] == False:
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
assert 'sum' in var_name or ("@GRAD" in var_name)
if 'sum' in var_name or (not shard.has_param(param)):
vars_status[var_name] = -1
else:
dp_grads_status[var_name] = -1
if ring_id != 0:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if ring_id != sharding_ring_id:
assert shard.has_param(param)
assert ring_id == dp_ring_id
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if "sum" in var_name:
idx_amp_allreduce = idx
elif "@GRAD":
idx_last_grad_allreduce = idx
if op.type == "c_allreduce_max":
idx_gradient_clip_allreduce = idx
......@@ -130,36 +138,38 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
dp_grads_status[var_name] = 1
elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
if op.all_attrs()["use_calc_stream"] == False:
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == sharding_ring_id:
if var_name in vars_status:
_status = vars_status[var_name]
else:
_status = dp_grads_status[var_name]
if _status == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(
var_name))
if _status == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (_status == 1)
if var_name in vars_status:
vars_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
dp_grads_status[var_name] = 2
else:
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
assert ring_id == dp_ring_id
param = var_name.split("@")[0]
assert shard.has_param(param)
assert dp_grads_status[var_name] == 3
dp_grads_status[var_name] = 4
elif op.type == "c_sync_comm_stream":
var_name = op.desc.input_arg_names()[0]
ring_id = op.desc.attr("ring_id")
if ring_id == 0:
if ring_id == sharding_ring_id:
for var_name in op.desc.input_arg_names():
if var_name in vars_status:
assert vars_status[var_name] == 2
......@@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx):
return OpRole.Forward, OpRole.Backward or OpRole.Optimize
"""
op_role = block.ops[insert_idx].attr('op_role')
if (insert_idx >= len(block.ops)) or (
op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
return OpRole.Backward
#if (insert_idx >= len(block.ops)) or (
# op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
# return OpRole.Backward
#if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
# return OpRole.Forward
if insert_idx >= len(block.ops): return OpRole.Optimize
if op_role == int(OpRole.Backward): return OpRole.Backward
if op_role == int(OpRole.Optimize): return OpRole.Optimize
if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
return OpRole.Forward
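# Note (illustrative, not part of the original commit): an insert position past the
# last op is treated as an Optimize-stage insert; otherwise the role of the op at
# insert_idx is propagated (Backward, Optimize, or Forward for Forward/Loss ops).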
......@@ -428,7 +443,7 @@ def comm_analyse(main_program):
count))
def add_sync_comm(program, dist_strategy):
def add_sync_comm(program, nccl_ids):
"""
When cloning a test prog from the sharding main prog,
some of the sync_comm ops may be pruned by mistake; this function
......@@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy):
#NOTE (liangjianzhong): only one comm stream is supported by now; using more than one
# comm stream will cause errors. Should be revised in the future.
assert isinstance(
nccl_ids, list
), "the second argument of this function should be a list of nccl_ids"
block = program.global_block()
not_sync_vars = set([])
for op in block.ops:
......@@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy):
for input_name in op.desc.input_arg_names():
not_sync_vars.remove(input_name)
if not_sync_vars:
for nccl_id in range(dist_strategy.nccl_comm_num):
for nccl_id in nccl_ids:
block.append_op(
type='c_sync_comm_stream',
inputs={'X': list(not_sync_vars)},
......@@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
This function handles the model saving for sharding training.
"""
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
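# Note (illustrative, not part of the original commit): with pipeline parallelism the
# persistable vars live in the per-stage section program, so saving switches to it
# instead of the outer wrapper program.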
def is_opt_vars(var):
# NOTE(liangjianzhong): The checks should be updated when adding new compatible optimizers
# now only Momentum and Adam are compatible with sharding
......
......@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
......@@ -39,6 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
"ModelParallelOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
......@@ -51,6 +52,10 @@ class ShardingOptimizer(MetaOptimizerBase):
self._reduced_grads_to_param = {}
self._shard = Shard()
# use sharding as outer parallelism (e.g. inner:Megatron & outer sharding)
self._as_outer_parallelism = False
self._inner_parallelism_size = None
def _can_apply(self):
if not self.role_maker._is_collective:
return False
......@@ -79,20 +84,61 @@ class ShardingOptimizer(MetaOptimizerBase):
"fuse_broadcast_MB"]
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
"as_outer_parallelism"]
self._inner_parallelism_size = int(
self.user_defined_strategy.sharding_configs[
"inner_parallelism_size"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if self.use_pipeline:
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
main_program = loss.block.program
main_program._pipeline_opt = dict()
pp_rank = self.role_maker._worker_index(
) // self.user_defined_strategy.sharding_configs[
'sharding_group_size']
main_program._pipeline_opt['local_rank'] = pp_rank
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 1
optimize_ops, params_grads, program_list = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
if self.use_pipeline:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", pp_rank)
main_program = program_list[pp_rank]['program']
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
new_params_grads = []
for param, grad in params_grads:
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
if self.use_pipeline:
pp_optimizer._rename_gradient_var_name(main_block)
# step1: set_up
self._set_up(params_grads)
......@@ -105,17 +151,200 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block._sync_with_cpp()
# step4: insert reduce_sum for grad
insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker._worker_num())
# grad_scale_coeff = self.role_maker._worker_num()
# if self._as_outer_parallelism:
# grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size
# insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size)
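# Worked example (illustrative, not part of the original commit): with
# sharding_group_size = 8 the loss gradient is scaled by 1/8, compensating for the
# subsequent allreduce_sum over the 8 sharding ranks.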
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
if self.hybrid_dp:
self._initialization_broadcast(startup_program)
if self.use_pipeline:
# crop ops: remove update ops whose parameters are not on this shard
for idx, op in reversed(list(enumerate(main_block.ops))):
# if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
# out_name = op.output_arg_names[0]
# if not 'GRAD' in out_name: continue
# param_name = out_name.strip("@GRAD")
# #if main_block.has_var(out_name): continue
# if self._shard.has_param(param_name): continue
# main_block._remove_op(idx)
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
param_list = []
for param_name, grad_name in params_grads:
if self._shard.has_param(param_name):
param_list.append(param_name)
#pp_optimizer._clear_gradients(main_block, param_list)
pp_optimizer._accumulate_gradients(main_block)
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
#assert main_block.has_var(grad_name)
#grad_var = main_block.vars[grad_name]
#grad_var.persistable = True
#main_block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [grad_var]},
# attrs={
# 'shape': grad_var.shape,
# 'dtype': grad_var.dtype,
# 'value': float(0),
# #self._op_device_key: device,
# # a trick to run this op once per mini-batch
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# })
#def _create_var(block, ref_var, name):
# """
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# name `name`.
# """
# new_var = block.create_var(
# name=name,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block.append_op(
# #index=offset + 1,
# type='cast',
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# #offset += 1
# main_block.append_op(
# #index=offset + 1,
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# attrs={
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# continue
# assert len(op_role_var) % 2 == 0
# offset = index
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# offset += 1
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block._insert_op(
# index=offset + 1,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# })
# offset += 1
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
main_block._sync_with_cpp()
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
# check op dependency
check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait()
return optimize_ops, params_grads
......@@ -134,11 +363,23 @@ class ShardingOptimizer(MetaOptimizerBase):
self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True)
# inner & outer model parallelism
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
# dp
if self.hybrid_dp:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
# pp
if self.use_pipeline:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
......@@ -205,8 +446,8 @@ class ShardingOptimizer(MetaOptimizerBase):
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
#assert (
# reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
......@@ -234,9 +475,14 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
# NOTE (JZ-LIANG) the sync of FoundInfinite should happen within one entire Model Parallelism
# group, and each Data Parallelism group should have its own sync of FoundInfinite
Model_Paramllelism_ring_id = self.sharding_ring_id
if self._as_outer_parallelism:
Model_Paramllelism_ring_id = self.mp_group_id
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self.sharding_ring_id)
gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
Model_Paramllelism_ring_id)
gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
gradientclip_helper.prune_gradient_clip(block, self._shard)
# build prog deps
......@@ -264,8 +510,13 @@ class ShardingOptimizer(MetaOptimizerBase):
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom"
"c_allreduce_sum",
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
]:
pass
elif op.type == "conditional_block":
......@@ -303,6 +554,14 @@ class ShardingOptimizer(MetaOptimizerBase):
program_deps.remove_op(idx)
block._sync_with_cpp()
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == 'concat' and is_optimizer_op(op):
# remove inputs that are not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
return
def _add_broadcast_allreduce(self, block):
......@@ -459,6 +718,7 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_comm(self):
if self.hybrid_dp:
assert self._as_outer_parallelism == False, "hybrid dp conflicts with using sharding as outer parallelism"
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
"sharding_group_size"]
self.sharding_ring_id = 0
......@@ -485,13 +745,109 @@ class ShardingOptimizer(MetaOptimizerBase):
self.global_word_size,
self.sharding_group_size,
self.dp_group_size)
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
logging.info("Using Sharing&DP mode !")
else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
if self._as_outer_parallelism:
self.sharding_ring_id = 1
assert self.global_word_size > self._inner_parallelism_size, \
"global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
assert self.global_word_size % self._inner_parallelism_size == 0, \
"global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
self.sharding_rank = self.global_rank // self._inner_parallelism_size
self.sharding_group_size = self.role_maker._worker_num(
) // self._inner_parallelism_size
_offset = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx % self._inner_parallelism_size == _offset
]
# the current entire model parallelism group is the combination of inner & sharding parallelism
self.mp_group_id = 2
self.mp_rank = self.global_rank
self.mp_group_size = self.role_maker._worker_num()
self.mp_group_endpoints = self.endpoints[:]
logging.info("Using Sharing as Outer parallelism mode !")
# print(
# "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer"
# )
# partition_idx = self.global_rank // self._inner_parallelism_size
# magetron_endpoints = self.endpoints[
# partition_idx * self._inner_parallelism_size:partition_idx *
# self._inner_parallelism_size + self._inner_parallelism_size]
# magetron_rank = self.global_rank % self._inner_parallelism_size
# self._collective_helper._init_communicator(
# program=self._startup_program,
# current_endpoint=self.current_endpoint,
# endpoints=magetron_endpoints,
# rank=magetron_rank,
# ring_id=0,
# wait_port=True)
# logging.info("megatron group size: {}".format(
# self._inner_parallelism_size))
# logging.info("megatron rank: {}".format(magetron_rank))
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
if self.use_pipeline:
self.sharding_ring_id = 0
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num(
)
self.pp_ring_id = 1
self.pp_rank = self.global_rank // self.sharding_group_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.pp_rank
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
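# Worked example (illustrative, not part of the original commit): with 8 workers,
# sharding_group_size = 4 and pipeline_nodes = 2, worker 6 gets
# sharding_rank = 6 % 4 = 2 and pp_rank = 6 // 4 = 1, so its sharding group is
# endpoints 4-7 (idx // 4 == 1) and its pipeline group is endpoints 2 and 6
# (idx % 4 == 2).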
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
logging.info("Using Sharing with pipeline !")
else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
logging.info("Using Sharing alone mode !")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
......@@ -503,12 +859,42 @@ class ShardingOptimizer(MetaOptimizerBase):
logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size))
logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("current model parallelism group_size: {}".format(
self.mp_group_size))
logging.info("current model parallelism rank: {}".format(self.mp_rank))
logging.info("dp group size: {}".format(self.dp_group_size))
logging.info("dp rank: {}".format(self.dp_rank))
logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("global word endpoints: {}".format(self.endpoints))
logging.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logging.info("current model parallelism group endpoints: {}".format(
self.mp_group_endpoints))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
logging.info("global word endpoints: {}".format(self.endpoints))
return
def _initialization_broadcast(self, startup_prog):
"""
this function ensures that the initialization across dp groups is
identical when hybrid-dp is used.
"""
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
params.append(param)
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
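# Note (illustrative, not part of the original commit): every parameter is broadcast
# from dp rank 0 over dp_ring_id so that all data-parallel replicas start from
# identical weights, followed by one c_sync_comm_stream over all params.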
......@@ -115,7 +115,7 @@ class ProgramStats(object):
updated_min_idx = min_idx
while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[
_logger.info("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0]))
updated_min_idx = idx_
......@@ -155,7 +155,7 @@ class ProgramStats(object):
sorted_checkpoints = []
for name in checkpoints_name:
if name not in self.var_op_deps:
_logger.debug(
_logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name)
elif self.var_op_deps[name]["var_as_output_ops"] == []:
......@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
start_idx = 0
pre_segment_end_idx = -1
while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1:
break
# min_idx: checkpoint_1' s input op
......@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
start_idx += 1
......@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
_logger.info("recompute segment[{}]".format(i))
_logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
_logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
_logger.info("recompute segment[{}]".format(i))
_logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
_logger.info("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory
......@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
_logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
_logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory
......@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
......@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
continue
if name not in var_name_dict:
var_name_dict[name] = name + var_suffix
# we should create the renamed var in the subprog, otherwise its VarType will be BOOL
block.create_var(
name=var_name_dict[name],
shape=block.program.global_block().var(name).shape,
dtype=block.program.global_block().var(name).dtype,
type=block.program.global_block().var(name).type,
persistable=block.program.global_block().var(
name).persistable,
stop_gradient=block.program.global_block().var(name)
.stop_gradient)
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory)
......