Commit 920806db authored by sandyhouse

update

Parent 70eb21c5
@@ -126,6 +126,9 @@ class ProgramDeps(object):
    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
        # remove check_finite_and_unscale op if its input 'X' is empty
        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
            return True
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
......
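For reference, the pruning rule added above can be exercised with a framework-free sketch; the `Op` class below is a hypothetical stand-in, not Paddle's `OpDesc` API.

```python
class Op:
    def __init__(self, op_type, inputs):
        self.type = op_type
        self._inputs = inputs  # slot name -> list of input var names

    def input(self, name):
        return self._inputs.get(name, [])


def should_remove_check_op(op):
    # A check_finite_and_unscale op whose 'X' slot is empty has nothing left to
    # check once sharding pruned all of its gradients, so it can be dropped.
    return op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0


print(should_remove_check_op(Op('check_finite_and_unscale', {'X': []})))      # True
print(should_remove_check_op(Op('check_finite_and_unscale', {'X': ['g0']})))  # False
```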
@@ -28,17 +28,20 @@ def check_broadcast(block):
    if the broadcasted var has a fill_constant op, the fill_constant
    op should stay forward before the broadcast op, and before a
    sync_calc op. Otherwise, raise error.
    should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron)
    """
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type == "c_broadcast":
            if op.all_attrs()["use_calc_stream"] == False:
                var_name = op.desc.input_arg_names()[0]
                if "@BroadCast" in var_name:
                    if var_name in broadcast_vars:
                        raise ValueError("var_name already exists: {}"
                                         "the old pos is {}, the new pos is {}".
                                         format(var_name, broadcast_vars[
                                             var_name]["broadcast_pos"], idx))
                    broadcast_vars[var_name] = {
                        "fill_constant_pos": -1,
                        "broadcast_pos": idx,
@@ -61,6 +64,7 @@ def check_broadcast(block):
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast":
            if op.all_attrs()["use_calc_stream"] == False:
                var_name = op.desc.input_arg_names()[0]
                if "@BroadCast" in var_name:
                    if broadcast_vars[var_name]["fill_constant_pos"] != -1:
@@ -78,7 +82,7 @@ def check_broadcast(block):
    return
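A minimal sketch of the filter introduced above: only `c_broadcast` ops placed on a communication stream (`use_calc_stream == False`) take part in the sharding check, while broadcasts issued by an inner parallelism such as Megatron are presumably run on the calc stream and skipped. The dict-based "ops" are illustrative stand-ins.

```python
def is_sharding_broadcast(op):
    # Only communication-stream broadcasts are validated by check_broadcast;
    # inner-parallelism broadcasts (use_calc_stream=True) are ignored.
    return op["type"] == "c_broadcast" and op["attrs"].get("use_calc_stream") is False


ops = [
    {"type": "c_broadcast", "attrs": {"use_calc_stream": True}},   # inner parallelism
    {"type": "c_broadcast", "attrs": {"use_calc_stream": False}},  # sharding broadcast
    {"type": "c_allreduce_sum", "attrs": {"use_calc_stream": False}},
]
print([is_sharding_broadcast(op) for op in ops])  # [False, True, False]
```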
def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
    """
    the op order should be:
        grad:
@@ -89,14 +93,18 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
            - 4: allreduce_sum_dp (dp_grads)
            - 5: sync_comm (dp_grads)
            - 6: op that uses Var (dp_grads & sum)
    should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron)
    """
    vars_status = {}
    dp_grads_status = {}
    idx_last_grad_allreduce = -1
    idx_amp_allreduce = -1
    idx_gradient_clip_allreduce = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_allreduce_sum":
            if op.all_attrs()["use_calc_stream"] == False:
                ring_id = op.desc.attr("ring_id")
                var_name = op.desc.input_arg_names()[0]
                param = var_name.split("@")[0]
@@ -107,7 +115,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                else:
                    dp_grads_status[var_name] = -1
                if ring_id != sharding_ring_id:
                    assert shard.has_param(param)
                    assert ring_id == dp_ring_id
@@ -130,16 +138,18 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
                    dp_grads_status[var_name] = 1
        elif op.type == "c_allreduce_sum":
            if op.all_attrs()["use_calc_stream"] == False:
                var_name = op.desc.input_arg_names()[0]
                ring_id = op.desc.attr("ring_id")
                if ring_id == sharding_ring_id:
                    if var_name in vars_status:
                        _status = vars_status[var_name]
                    else:
                        _status = dp_grads_status[var_name]
                    if _status == -1:
                        raise ValueError("{} is not generated, but you are"
                                         "trying to all-reduce it".format(
                                             var_name))
                    if _status == 0:
                        raise ValueError("There should be a sync_calc op "
                                         "after generate Var: {} and before the"
@@ -159,7 +169,7 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1):
        elif op.type == "c_sync_comm_stream":
            var_name = op.desc.input_arg_names()[0]
            ring_id = op.desc.attr("ring_id")
            if ring_id == sharding_ring_id:
                for var_name in op.desc.input_arg_names():
                    if var_name in vars_status:
                        assert vars_status[var_name] == 2
@@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx):
        return OpRole.Forward or OpRole.Backward
    """
    op_role = block.ops[insert_idx].attr('op_role')
    #if (insert_idx >= len(block.ops)) or (
    #        op_role in [int(OpRole.Backward), int(OpRole.Optimize)]):
    #    return OpRole.Backward
    #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
    #    return OpRole.Forward
    if insert_idx >= len(block.ops): return OpRole.Optimize
    if op_role == int(OpRole.Backward): return OpRole.Backward
    if op_role == int(OpRole.Optimize): return OpRole.Optimize
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward
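The revised role mapping can be mirrored with a small self-contained sketch; the `OpRole` enum below uses hypothetical values purely to exercise the branching, and is not Paddle's actual `OpRole`.

```python
from enum import IntEnum


class OpRole(IntEnum):
    # Hypothetical stand-in values; the real OpRole lives in paddle.fluid.
    Forward = 0
    Backward = 1
    Optimize = 2
    Loss = 256


def valid_op_role(op_roles, insert_idx):
    # Past-the-end insert positions count as Optimize; otherwise the role of the
    # op at insert_idx decides, with Loss folded back into Forward.
    if insert_idx >= len(op_roles):
        return OpRole.Optimize
    role = op_roles[insert_idx]
    if role == int(OpRole.Backward):
        return OpRole.Backward
    if role == int(OpRole.Optimize):
        return OpRole.Optimize
    if role in (int(OpRole.Forward), int(OpRole.Loss)):
        return OpRole.Forward


print(valid_op_role([int(OpRole.Forward), int(OpRole.Backward)], 1))  # OpRole.Backward
print(valid_op_role([int(OpRole.Forward)], 5))                        # OpRole.Optimize
```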
@@ -428,7 +443,7 @@ def comm_analyse(main_program):
                                                       count))

def add_sync_comm(program, nccl_ids):
    """
    When a test prog is cloned from the sharding main prog,
    some of the sync_comm ops may be pruned by mistake; this function
@@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy):
    # NOTE (liangjianzhong): only one comm stream is supported for now; using more
    # than one comm stream will cause errors. This should be revised in the future.
    assert isinstance(
        nccl_ids, list
    ), "the second argument of this function should be a list of nccl_ids"
    block = program.global_block()
    not_sync_vars = set([])
    for op in block.ops:
@@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy):
            for input_name in op.desc.input_arg_names():
                not_sync_vars.remove(input_name)
    if not_sync_vars:
        for nccl_id in nccl_ids:
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': list(not_sync_vars)},
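Roughly, the new signature asks the caller for the list of ring ids instead of a `DistributedStrategy`. A framework-free sketch of the intent, with plain dicts standing in for `block.ops`, under the assumption that any op whose type starts with `c_` is a communication op:

```python
def plan_sync_ops(ops, nccl_ids):
    # Outputs of communication ops that no later op consumes still need an
    # explicit sync; one c_sync_comm_stream is planned per ring id in nccl_ids.
    assert isinstance(nccl_ids, list), "nccl_ids should be a list of ring ids"
    not_sync_vars = set()
    for op in ops:
        if op["type"].startswith("c_"):
            not_sync_vars.update(op["outputs"])
        else:
            not_sync_vars.difference_update(op["inputs"])
    if not not_sync_vars:
        return []
    return [{"type": "c_sync_comm_stream", "X": sorted(not_sync_vars), "ring_id": rid}
            for rid in nccl_ids]


ops = [{"type": "c_broadcast", "inputs": [], "outputs": ["fc_0.w_0@BroadCast"]}]
print(plan_sync_ops(ops, [0, 1]))
```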
@@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
    This function handles the model saving for sharding training.
    """
    if main_program._pipeline_opt:
        main_program = main_program._pipeline_opt['section_program']['program']

    def is_opt_vars(var):
        # NOTE(liangjianzhong): the checks should be updated when adding a new compatible optimizer;
        # for now only Momentum and Adam are compatible with sharding
......
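The added unwrapping step can be illustrated without Paddle; `FakeProgram` below is a hypothetical stand-in that only carries the `_pipeline_opt` attribute used above.

```python
class FakeProgram:
    # Hypothetical stand-in: only carries the _pipeline_opt attribute.
    def __init__(self, pipeline_opt=None):
        self._pipeline_opt = pipeline_opt


def resolve_program_for_saving(main_program):
    pipeline_opt = getattr(main_program, "_pipeline_opt", None)
    if pipeline_opt:
        # Under pipeline training, persistables live in the per-stage section program.
        return pipeline_opt['section_program']['program']
    return main_program


section = FakeProgram()
wrapped = FakeProgram({'section_program': {'program': section}})
plain = FakeProgram()
assert resolve_program_for_saving(wrapped) is section
assert resolve_program_for_saving(plain) is plain
```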
@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
@@ -39,6 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
            "AMPOptimizer",
            "LarsOptimizer",
            "LambOptimizer",
            "ModelParallelOptimizer",
        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
        self._main_program = None
@@ -51,6 +52,10 @@ class ShardingOptimizer(MetaOptimizerBase):
        self._reduced_grads_to_param = {}
        self._shard = Shard()
        # use sharding as outer parallelism (e.g. inner: Megatron & outer: sharding)
        self._as_outer_parallelism = False
        self._inner_parallelism_size = None

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
@@ -79,20 +84,61 @@ class ShardingOptimizer(MetaOptimizerBase):
            "fuse_broadcast_MB"]
        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
            "hybrid_dp"]
self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
"as_outer_parallelism"]
self._inner_parallelism_size = int(
self.user_defined_strategy.sharding_configs[
"inner_parallelism_size"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"]
        if self.inner_opt is None:
            raise ValueError(
                "self.inner_opt of ShardingOptimizer should not be None.")
if self.use_pipeline:
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
main_program = loss.block.program
main_program._pipeline_opt = dict()
pp_rank = self.role_maker._worker_index(
) // self.user_defined_strategy.sharding_configs[
'sharding_group_size']
main_program._pipeline_opt['local_rank'] = pp_rank
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 1
optimize_ops, params_grads, program_list = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
            optimize_ops, params_grads = self.inner_opt.minimize(
                loss, startup_program, parameter_list, no_grad_set)

        if startup_program is None:
            startup_program = default_startup_program()
if self.use_pipeline:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", pp_rank)
main_program = program_list[pp_rank]['program']
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
new_params_grads = []
for param, grad in params_grads:
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
            main_block = loss.block
        startup_block = startup_program.global_block()
        self._main_program = main_block.program
        self._startup_program = startup_program
if self.use_pipeline:
pp_optimizer._rename_gradient_var_name(main_block)
        # step1: set_up
        self._set_up(params_grads)
@@ -105,17 +151,200 @@ class ShardingOptimizer(MetaOptimizerBase):
        startup_block._sync_with_cpp()

        # step4: insert reduce_sum for grad
        # grad_scale_coeff = self.role_maker._worker_num()
        # if self._as_outer_parallelism:
        #     grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size
        # insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff)
        sharding_group_size = self.user_defined_strategy.sharding_configs[
            'sharding_group_size']
        insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size)
        main_block._sync_with_cpp()

        # step5: remove unneeded ops and vars from block
        self._prune_main_program(main_block)
        self._prune_startup_program(startup_block)
if self.hybrid_dp:
self._initialization_broadcast(startup_program)
if self.use_pipeline:
# crop ops
for idx, op in reversed(list(enumerate(main_block.ops))):
# if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
# out_name = op.output_arg_names[0]
# if not 'GRAD' in out_name: continue
# param_name = out_name.strip("@GRAD")
# #if main_block.has_var(out_name): continue
# if self._shard.has_param(param_name): continue
# main_block._remove_op(idx)
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
param_list = []
for param_name, grad_name in params_grads:
if self._shard.has_param(param_name):
param_list.append(param_name)
#pp_optimizer._clear_gradients(main_block, param_list)
pp_optimizer._accumulate_gradients(main_block)
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
#assert main_block.has_var(grad_name)
#grad_var = main_block.vars[grad_name]
#grad_var.persistable = True
#main_block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [grad_var]},
# attrs={
# 'shape': grad_var.shape,
# 'dtype': grad_var.dtype,
# 'value': float(0),
# #self._op_device_key: device,
# # a trick to run this op once per mini-batch
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# })
#def _create_var(block, ref_var, name):
# """
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# name `name`.
# """
# new_var = block.create_var(
# name=name,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block.append_op(
# #index=offset + 1,
# type='cast',
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# #offset += 1
# main_block.append_op(
# #index=offset + 1,
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# attrs={
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# continue
# assert len(op_role_var) % 2 == 0
# offset = index
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# offset += 1
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block._insert_op(
# index=offset + 1,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# })
# offset += 1
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
main_block._sync_with_cpp()
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
        # check op dependency
        check_broadcast(main_block)
        check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
                            self.dp_ring_id)
        #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
        self._wait()
        return optimize_ops, params_grads
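One way to read the gradient-scaling change above: the scale used to be `1 / worker_num`, while this revision scales by the sharding group size only (the commented-out branch would further divide by the inner parallelism size). The numbers below are illustrative.

```python
def old_loss_grad_scale(worker_num):
    return 1.0 / worker_num


def new_loss_grad_scale(sharding_group_size):
    return 1.0 / sharding_group_size


# e.g. 16 workers arranged as 2 pipeline stages x 8-way sharding
print(old_loss_grad_scale(16))  # 0.0625
print(new_loss_grad_scale(8))   # 0.125
```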
@@ -134,11 +363,23 @@ class ShardingOptimizer(MetaOptimizerBase):
            self._startup_program, self.current_endpoint,
            self.sharding_group_endpoints, self.sharding_rank,
            self.sharding_ring_id, True)
# inner & outer model parallelism
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
        # dp
        if self.hybrid_dp:
            self._collective_helper._init_communicator(
                self._startup_program, self.current_endpoint,
                self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
# pp
if self.use_pipeline:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True)
        startup_block = self._startup_program.global_block()
        startup_block._sync_with_cpp()
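The communicator bring-up implied above can be summarized with a small sketch; the flag names mirror this patch and the group labels are placeholders, not a public API.

```python
def comm_groups_to_init(as_outer_parallelism, hybrid_dp, use_pipeline):
    # The sharding ring is always initialized; the others are optional and
    # mirror the flags used in this patch.
    groups = ["sharding"]
    if as_outer_parallelism:
        groups.append("mp")   # combined inner parallelism + sharding group
    if hybrid_dp:
        groups.append("dp")
    if use_pipeline:
        groups.append("pp")
    return groups


print(comm_groups_to_init(True, False, False))   # ['sharding', 'mp']
print(comm_groups_to_init(False, True, False))   # ['sharding', 'dp']
print(comm_groups_to_init(False, False, True))   # ['sharding', 'pp']
```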
@@ -205,8 +446,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                    for i in range(0, len(op_role_var), 2):
                        param, reduced_grad = op_role_var[i], op_role_var[i + 1]
                        segment._allreduce_vars.append(reduced_grad)
                        #assert (
                        #    reduced_grad not in self._reduced_grads_to_param)
                        self._reduced_grads_to_param[reduced_grad] = param

                # find cast op
@@ -234,9 +475,14 @@ class ShardingOptimizer(MetaOptimizerBase):
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, self._shard)
        # NOTE (JZ-LIANG) the sync of FoundInfinite should span one entire Model Parallelism
        # group, and each Data Parallelism group should have its own sync of FoundInfinite
        Model_Parallelism_ring_id = self.sharding_ring_id
        if self._as_outer_parallelism:
            Model_Parallelism_ring_id = self.mp_group_id
        FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
                             Model_Parallelism_ring_id)
        gradientclip_helper = GradientClipHelper(Model_Parallelism_ring_id)
        gradientclip_helper.prune_gradient_clip(block, self._shard)

        # build prog deps
@@ -264,8 +510,13 @@ class ShardingOptimizer(MetaOptimizerBase):
        # Prune
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in [
                    "c_allreduce_sum",
                    "c_sync_comm_stream",
                    "c_calc_comm_stream",
                    "c_gen_nccl_id",
                    "c_comm_init",
                    'send_v2',
                    'recv_v2',
            ]:
                pass
            elif op.type == "conditional_block":
@@ -303,6 +554,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                    program_deps.remove_op(idx)

        block._sync_with_cpp()
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == 'concat' and is_optimizer_op(op):
# remove inputs that not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
        return
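The concat-pruning loop added above keeps only the inputs that still exist on this card; a minimal sketch of that rule, with plain names instead of block vars:

```python
def prune_concat_inputs(input_names, vars_on_this_card):
    # Keep only the inputs that survived sharding on this card; the others were
    # pruned together with their producing ops.
    return [name for name in input_names if name in vars_on_this_card]


print(prune_concat_inputs(["p0@GRAD", "p1@GRAD", "p2@GRAD"], {"p1@GRAD"}))
# ['p1@GRAD']
```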
    def _add_broadcast_allreduce(self, block):
@@ -459,6 +718,7 @@ class ShardingOptimizer(MetaOptimizerBase):
    def _init_comm(self):
        if self.hybrid_dp:
            assert self._as_outer_parallelism == False, "hybrid_dp conflicts with using sharding as outer parallelism"
            self.sharding_group_size = self.user_defined_strategy.sharding_configs[
                "sharding_group_size"]
            self.sharding_ring_id = 0
@@ -485,13 +745,109 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.global_word_size,
                self.sharding_group_size,
                self.dp_group_size)
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
            logging.info("Using Sharding & DP mode!")
else:
if self._as_outer_parallelism:
self.sharding_ring_id = 1
assert self.global_word_size > self._inner_parallelism_size, \
"global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
                assert self.global_word_size % self._inner_parallelism_size == 0, \
                    "global_word_size: {} should be divisible by the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
self.sharding_rank = self.global_rank // self._inner_parallelism_size
self.sharding_group_size = self.role_maker._worker_num(
) // self._inner_parallelism_size
_offset = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx % self._inner_parallelism_size == _offset
]
                # the current entire model parallelism group is the combination of inner & sharding parallelism
self.mp_group_id = 2
self.mp_rank = self.global_rank
self.mp_group_size = self.role_maker._worker_num()
self.mp_group_endpoints = self.endpoints[:]
                logging.info("Using Sharding as outer parallelism mode!")
# print(
# "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer"
# )
# partition_idx = self.global_rank // self._inner_parallelism_size
# magetron_endpoints = self.endpoints[
# partition_idx * self._inner_parallelism_size:partition_idx *
# self._inner_parallelism_size + self._inner_parallelism_size]
# magetron_rank = self.global_rank % self._inner_parallelism_size
# self._collective_helper._init_communicator(
# program=self._startup_program,
# current_endpoint=self.current_endpoint,
# endpoints=magetron_endpoints,
# rank=magetron_rank,
# ring_id=0,
# wait_port=True)
# logging.info("megatron group size: {}".format(
# self._inner_parallelism_size))
# logging.info("megatron rank: {}".format(magetron_rank))
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
if self.use_pipeline:
self.sharding_ring_id = 0
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num(
)
self.pp_ring_id = 1
self.pp_rank = self.global_rank // self.sharding_group_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.pp_rank
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
                logging.info("Using Sharding with pipeline!")
            else:
                self.sharding_ring_id = 0
                self.sharding_rank = self.global_rank
                self.sharding_group_size = self.role_maker._worker_num()
                self.sharding_group_endpoints = self.endpoints
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
                logging.info("Using Sharding alone mode!")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
                self.dp_ring_id = -1
                self.dp_rank = -1
                self.dp_group_size = None
@@ -503,12 +859,42 @@ class ShardingOptimizer(MetaOptimizerBase):
        logging.info("global rank: {}".format(self.global_rank))
        logging.info("sharding group_size: {}".format(self.sharding_group_size))
        logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("current model parallelism group_size: {}".format(
self.mp_group_size))
logging.info("current model parallelism rank: {}".format(self.mp_rank))
        logging.info("dp group size: {}".format(self.dp_group_size))
        logging.info("dp rank: {}".format(self.dp_rank))
        logging.info("current endpoint: {}".format(self.current_endpoint))
        logging.info("global world endpoints: {}".format(self.endpoints))
        logging.info("sharding group endpoints: {}".format(
            self.sharding_group_endpoints))
logging.info("current model parallelism group endpoints: {}".format(
self.mp_group_endpoints))
        logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
        logging.info("global world endpoints: {}".format(self.endpoints))
        return
    def _initialization_broadcast(self, startup_prog):
        """
        this function ensures that the initialization is identical across
        the dp group when hybrid-dp is used.
        """
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
params.append(param)
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
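The startup-program edit above appends one `c_broadcast` per parameter over the dp ring (root 0) plus a trailing `c_sync_comm_stream`, so every data-parallel replica starts from identical weights. A framework-free sketch of that op plan, with plain dicts standing in for appended ops:

```python
def plan_initialization_broadcast(param_names, dp_ring_id):
    # One broadcast per parameter from rank 0 of the dp ring, followed by a
    # single stream sync, so all dp replicas share the same initial weights.
    ops = [{"type": "c_broadcast", "X": name, "ring_id": dp_ring_id, "root": 0}
           for name in param_names]
    ops.append({"type": "c_sync_comm_stream", "X": list(param_names),
                "ring_id": dp_ring_id})
    return ops


for op in plan_initialization_broadcast(["fc_0.w_0", "fc_0.b_0"], dp_ring_id=2):
    print(op)
```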
@@ -115,7 +115,7 @@ class ProgramStats(object):
        updated_min_idx = min_idx
        while idx_ > pre_segment_end_idx:
            if is_amp_cast(self.ops[idx_]):
                _logger.info("found amp-cast op: {}, : {}".format(self.ops[
                    idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
                        0]))
                updated_min_idx = idx_
@@ -155,7 +155,7 @@ class ProgramStats(object):
        sorted_checkpoints = []
        for name in checkpoints_name:
            if name not in self.var_op_deps:
                _logger.info(
                    "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
                    % name)
            elif self.var_op_deps[name]["var_as_output_ops"] == []:
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(desc)
            new_op_desc._set_attr(op_role_attr_name, backward)
            if desc.has_attr('op_device'):
                new_op_desc._set_attr('op_device', desc.attr('op_device'))
            result_descs.append(new_op_desc)
    return result_descs
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(desc)
            new_op_desc._set_attr(op_role_attr_name, backward)
            if desc.has_attr('op_device'):
                new_op_desc._set_attr('op_device', desc.attr('op_device'))
            result_descs.append(new_op_desc)
    return result_descs
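Both helpers now forward the `op_device` attribute when copying an op desc, so the cloned op is still placed on the correct pipeline device. A simplified stand-in of that copying rule, with dicts instead of `OpDesc` objects:

```python
def copy_desc_with_device(desc, op_role="backward"):
    # Copy the op type, tag it with the backward role, and forward the
    # op_device attribute only when the source op carries one.
    new_desc = {"type": desc["type"], "attrs": {"op_role": op_role}}
    if "op_device" in desc.get("attrs", {}):
        new_desc["attrs"]["op_device"] = desc["attrs"]["op_device"]
    return new_desc


print(copy_desc_with_device({"type": "matmul", "attrs": {"op_device": "gpu:1"}}))
print(copy_desc_with_device({"type": "relu", "attrs": {}}))
```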
@@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_(
    start_idx = 0
    pre_segment_end_idx = -1
    while True:
        _logger.debug("FW op range[0] - [{}]".format(len(ops)))
        if start_idx >= len(checkpoints_name) - 1:
            break
        # min_idx: checkpoint_1's input op
@@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_(
            min_idx = program_stat._update_segment_start(
                min_idx, pre_segment_end_idx)
            segments.append([min_idx, max_idx + 1])
else:
_logger.info("Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1))
        start_idx += 1
@@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_(
        recompute_segments = segments

    for i, (idx1, idx2) in enumerate(recompute_segments):
        _logger.info("recompute segment[{}]".format(i))
        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
        ), ops[idx1].desc.input_arg_names()))
        _logger.info("segment end op: [{}]: [{}]".format(ops[
            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
        _logger.info("recompute segment[{}]".format(i))
        _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
        ), ops[idx1].desc.input_arg_names()))
        _logger.info("segment end op: [{}]: [{}]".format(ops[
            idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
    # 2) go through all forward ops and induct all variables that will be hold in memory
@@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_(
            program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))

    cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
        len(cross_vars), cross_vars))
    _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
        len(cross_vars), cross_vars))
...@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_( ...@@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops) max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []: if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position] gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops): for op in reversed(gap_ops):
...@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_( ...@@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block")) _pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), []) op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
            added_descs = _add_descs_to_block(grad_op_desc, local_block)
            grad_op_descs.extend(added_descs)
            grad_to_var.update(op_grad_to_var)
@@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_(
                    _pretty_op_desc_(op.desc, "with_sub_block"))
            grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
            added_descs = _add_descs_to_block(grad_op_desc, local_block)
            grad_op_descs.extend(added_descs)
            grad_to_var.update(op_grad_to_var)
@@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_(
                continue
            if name not in var_name_dict:
                var_name_dict[name] = name + var_suffix
                # we should create the renamed var in the subprog, otherwise its VarType will be BOOL
block.create_var(
name=var_name_dict[name],
shape=block.program.global_block().var(name).shape,
dtype=block.program.global_block().var(name).dtype,
type=block.program.global_block().var(name).type,
persistable=block.program.global_block().var(
name).persistable,
stop_gradient=block.program.global_block().var(name)
.stop_gradient)
        # 3.a. add ops in current recompute_segment as forward recomputation ops
        buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
                                                  vars_in_memory)
......
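The explicit `create_var` call added in this hunk clones the source var's shape, dtype and type under the renamed checkpoint name, so the recompute copy does not fall back to a default (e.g. BOOL) VarType. A small illustrative sketch of that cloning rule, using a hypothetical dataclass rather than Paddle's Variable:

```python
from dataclasses import dataclass, replace


@dataclass
class VarMeta:
    # Illustrative stand-in for the fields copied by block.create_var above.
    name: str
    shape: tuple
    dtype: str
    type: str = "LOD_TENSOR"
    persistable: bool = False
    stop_gradient: bool = False


def make_renamed_copy(src: VarMeta, suffix: str) -> VarMeta:
    # Clone every metadata field and only change the name, so the recompute
    # copy keeps the original dtype and VarType.
    return replace(src, name=src.name + suffix)


print(make_renamed_copy(VarMeta("hidden_0", (32, 768), "float16"), ".subprog_0"))
```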