Commit c7472f16, authored by root, committed by sandyhouse

update

Parent eeca5ef6
@@ -38,6 +38,7 @@ message ShardingConfig {
   optional int32 acc_steps = 7 [ default = 1 ];
   optional int32 schedule_mode = 8 [ default = 0 ];
   optional int32 pp_bz = 9 [ default = 1 ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = true ];
 }
 
 message AMPConfig {
...
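For orientation, the new field is surfaced to users through fleet's sharding_configs. Below is a minimal, hypothetical configuration sketch for this development branch; the key names are taken from the ShardingConfig fields above and may differ in released Paddle versions, and the values are placeholders.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# Key names mirror the ShardingConfig fields above; values are placeholders.
strategy.sharding_configs = {
    "acc_steps": 4,                    # gradient accumulation steps
    "schedule_mode": 0,                # pipeline schedule mode
    "pp_bz": 1,                        # pipeline micro-batch size
    "pp_allreduce_in_optimize": True,  # new switch added in this commit
}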
@@ -88,7 +88,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     grad:
         - 0: op that generate Var
         - 1: sync_calc
-        - 2: allreduce_sum_sharding
+        - 2: reduce_sum_sharding (allreduce --> reduce)
         - 3: sync_comm
         - 4: allreuce_sum_dp (dp_grads)
         - 5: sync_comm (dp_grads)
@@ -103,7 +103,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     idx_gradient_clip_allreduce = -1
     for idx, op in enumerate(block.ops):
-        if op.type == "c_allreduce_sum":
+        if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
             if op.all_attrs()["use_calc_stream"] == False:
                 ring_id = op.desc.attr("ring_id")
                 var_name = op.desc.input_arg_names()[0]
@@ -137,11 +137,12 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
                         var_name] == 0:
                     dp_grads_status[var_name] = 1
-        elif op.type == "c_allreduce_sum":
+        elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
             if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 ring_id = op.desc.attr("ring_id")
                 if ring_id == sharding_ring_id:
+                    assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
                     if var_name in vars_status:
                         _status = vars_status[var_name]
                     else:
...@@ -191,6 +192,9 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): ...@@ -191,6 +192,9 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
raise ValueError("There should be a sync_comm op " raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format( "after allreduce the Var: {}".format(
input_name)) input_name))
raise ValueError(
"The reduce output grad [{}] should NOT be be used in Non-root rank.".
format(input_name))
if input_name in dp_grads_status: if input_name in dp_grads_status:
if dp_ring_id == -1: if dp_ring_id == -1:
if dp_grads_status[input_name] != 3: if dp_grads_status[input_name] != 3:
@@ -352,7 +356,9 @@ def get_grad_device(grad_name, shard):
         grad_name)
     base_name = None
     # mind the traversal order
-    possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
+    possible_suffixes = [
+        '.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD'
+    ]
     for suffix in possible_suffixes:
         if suffix in grad_name:
             base_name = re.sub(suffix, '', grad_name)
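As an aside, the "mind the traversal order" comment matters because '@GRAD' is a substring of every other suffix, so the most specific suffix has to be tried first. A tiny standalone illustration with made-up variable names:

import re

# Hypothetical gradient variable names, for illustration only.
possible_suffixes = ['.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD']

def base_param_name(grad_name):
    # Try the most specific suffix first; matching '@GRAD' too early on
    # 'fc_0.w_0.cast_fp16@GRAD_0' would leave '.cast_fp16' and '_0' behind.
    for suffix in possible_suffixes:
        if suffix in grad_name:
            return re.sub(suffix, '', grad_name)
    return None

print(base_param_name('fc_0.w_0.cast_fp16@GRAD_0'))  # fc_0.w_0
print(base_param_name('fc_0.b_0@GRAD'))              # fc_0.b_0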
@@ -369,7 +375,7 @@ def insert_reduce_ops(block,
                       ring_id,
                       reduce_vars,
                       shard,
-                      op_role,
+                      op_role=OpRole.Backward,
                       use_calc_stream=False):
     """
     _add_allreduce_ops
@@ -389,10 +395,18 @@ def insert_reduce_ops(block,
                 'use_calc_stream': use_calc_stream,
                 OP_ROLE_KEY: op_role
             })
     return
+
+
+def get_first_check_finite_and_unscale_op_idx(block):
+    for idx, op in enumerate(block.ops):
+        if op.type == "check_finite_and_unscale":
+            return idx
+
+    raise ValueError("check_finite_and_unscale does not exist in block")
 
 
 def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
     """
     _add_broadcast_ops
...
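For reference, only the signature and the attribute dict of insert_reduce_ops appear in the hunk above. A rough sketch of the expected body follows: one c_reduce_sum per gradient, rooted at the rank that owns the corresponding parameter shard. This is a sketch, not the committed implementation, and the 'ring_id' / 'root_id' attribute names are assumed to match the other collective ops in this module.

# Sketch only: approximates what insert_reduce_ops is expected to do,
# reusing this module's OpRole, OP_ROLE_KEY and get_grad_device helpers.
def insert_reduce_ops_sketch(block, insert_idx, ring_id, reduce_vars, shard,
                             op_role=OpRole.Backward, use_calc_stream=False):
    for var in reduce_vars:
        root_id = get_grad_device(var, shard)  # rank owning this grad's param
        block._insert_op_without_sync(
            insert_idx,
            type='c_reduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
            attrs={
                'ring_id': ring_id,
                'root_id': root_id,
                'use_calc_stream': use_calc_stream,
                OP_ROLE_KEY: op_role,
            })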
@@ -100,6 +100,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.schedule_mode = self.user_defined_strategy.sharding_configs[
             "schedule_mode"]
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
+        self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
+            "pp_allreduce_in_optimize"]
 
         if self.inner_opt is None:
             raise ValueError(
@@ -179,6 +181,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             self._initialization_broadcast(startup_program)
 
         if self.use_pipeline:
+            # pp_optimizer._rename_gradient_var_name(main_block)
             # crop ops
             for idx, op in reversed(list(enumerate(main_block.ops))):
                 # if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
@@ -207,20 +210,22 @@ class ShardingOptimizer(MetaOptimizerBase):
             #         param_list.append(param_name)
             #pp_optimizer._clear_gradients(main_block, param_list)
             accumulated_grad_names = pp_optimizer._accumulate_gradients(
-                main_block)
-            # accumulated_grad_names = sorted(accumulated_grad_names)
-            print("persistable FP32 grad: ")
-            print(accumulated_grad_names)
-            first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
-                main_block)
-            insert_reduce_ops(
-                main_block,
-                first_optimize_op_index,
-                self.sharding_ring_id,
-                accumulated_grad_names,
-                self._shard,
-                core.op_proto_and_checker_maker.OpRole.Optimize,
-                use_calc_stream=True)
+                main_block,
+                pp_allreduce_in_optimize=self.pp_allreduce_in_optimize)
+            # accumulated_grad_names = sorted(accumulated_grad_names)
+            if self.pp_allreduce_in_optimize:
+                print("persistable FP32 grad: ")
+                print(accumulated_grad_names)
+                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                    main_block)
+                insert_reduce_ops(
+                    main_block,
+                    first_optimize_op_index,
+                    self.sharding_ring_id,
+                    accumulated_grad_names,
+                    self._shard,
+                    core.op_proto_and_checker_maker.OpRole.Optimize,
+                    use_calc_stream=True)
             #if not self._shard.has_param(param_name): continue
             ##if not main_block.has_var(grad_name): continue
             #assert main_block.has_var(grad_name)
@@ -240,130 +245,130 @@ class ShardingOptimizer(MetaOptimizerBase):
             # 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
             # })
             #def _create_var(block, ref_var, name):
             # """
             # Create a new var for block, which has the same type,
             # shape and dtype as ref_var, then rename it with the
             # name `name`.
             # """
             # new_var = block.create_var(
             # name=name,
             # shape=ref_var.shape,
             # dtype=ref_var.dtype,
             # type=ref_var.type,
             # lod_level=ref_var.lod_level,
             # persistable=ref_var.persistable,
             # is_data=ref_var.is_data,
             # need_check_feed=ref_var.desc.need_check_feed())
             # new_var.stop_gradient = ref_var.stop_gradient
             # return new_var
             #def _rename_arg(op, old_name, new_name):
             # op_desc = op.desc
             # if isinstance(op_desc, tuple):
             # op_desc = op_desc[0]
             # op_desc._rename_input(old_name, new_name)
             # op_desc._rename_output(old_name, new_name)
             #print("params_grads:", params_grads)
             #for param_name, grad_name in params_grads:
             # if not self._shard.has_param(param_name): continue
             # #if not main_block.has_var(grad_name): continue
             # assert main_block.has_var(grad_name)
             # use_fp16 = False
             # fp16_grad_name = param_name + '.cast_fp16@GRAD'
             # if main_block.has_var(grad_name):
             # fp16_grad_var = main_block.vars[fp16_grad_name]
             # use_fp16 = True
             # grad_var = main_block.vars[grad_name]
             # if use_fp16:
             # cast_grad_var_name = paddle.fluid.unique_name.generate(
             # grad_name)
             # cast_var = _create_var(main_block, fp16_grad_var,
             # cast_grad_var_name)
             # cast_var.persistable = False
             # main_block.append_op(
             # #index=offset + 1,
             # type='cast',
             # inputs={'X': grad_var},
             # outputs={'Out': cast_var},
             # attrs={
             # 'in_dtype': grad_var.dtype,
             # 'out_dtype': cast_var.dtype,
             # 'op_role':
             # core.op_proto_and_checker_maker.OpRole.Backward,
             # })
             # #offset += 1
             # main_block.append_op(
             # #index=offset + 1,
             # type='sum',
             # inputs={'X': [fp16_grad_var, cast_var]},
             # outputs={'Out': fp16_grad_var},
             # attrs={
             # 'op_role':
             # core.op_proto_and_checker_maker.OpRole.Backward,
             # 'op_role_var': op_role_var
             # })
             # for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
             # offset = index
             # if is_backward_op(op) and (
             # 'op_role_var' in op.attr_names):
             # op_role_var = op.all_attrs()['op_role_var']
             # if len(op_role_var) == 0:
             # continue
             # assert len(op_role_var) % 2 == 0
             # offset = index
             # for i in range(0, len(op_role_var), 2):
             # grad_name = op_role_var[i + 1]
             # if not main_block.has_var(grad_name): continue
             # grad_var = main_block.vars[grad_name]
             # if not 'cast_fp16' in grad_name:
             # new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
             # new_var = _create_var(main_block, grad_var,
             # new_grad_var_name)
             # new_var.persistable = False
             # _rename_arg(op, grad_name, new_grad_var_name)
             # main_block._insert_op(
             # index=offset + 1,
             # type='sum',
             # inputs={'X': [grad_var, new_var]},
             # outputs={'Out': grad_var},
             # attrs={
             # 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
             # 'op_role_var': op_role_var
             # })
             # offset += 1
             # if 'cast_fp16' in grad_name:
             # param_name = op_role_var[i]
             # fp32_grad_var_name = param_name + "@GRAD"
             # fp32_grad_var = main_block.vars[grad_name]
             # cast_grad_var_name = paddle.fluid.unique_name.generate(
             # fp32_grad_var_name)
             # cast_var = _create_var(main_block, grad_var,
             # cast_grad_var_name)
             # cast_var.persistable = False
             # main_block._insert_op(
             # index=offset + 1,
             # type='cast',
             # inputs={'X': fp32_grad_var},
             # outputs={'Out': cast_var},
             # attrs={
             # 'in_dtype': fp32_grad_var.dtype,
             # 'out_dtype': cast_var.dtype,
             # 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
             # # self._op_role_var_key: op_role_var
             # })
             # offset += 1
             # main_block._insert_op(
             # index=offset + 1,
             # type='sum',
             # inputs={'X': [grad_var, cast_var]},
             # outputs={'Out': grad_var},
             # attrs={
             # 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
             # 'op_role_var': op_role_var})
             main_block._sync_with_cpp()
             with open("start_sharding_%d" % self.role_maker._worker_index(),
@@ -540,7 +545,10 @@ class ShardingOptimizer(MetaOptimizerBase):
                     self._main_program.global_block().var(input_name))
 
             # find reduce vars
-            if not self.use_pipeline:
+            if self.use_pipeline and self.pp_allreduce_in_optimize:
+                # place pipeline gradient allreduce in optimize
+                pass
+            else:
                 if is_backward_op(op) and \
                         OP_ROLE_VAR_KEY in op.attr_names:
                     op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
@@ -678,7 +686,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         if len(self._segments) < 1:
             return
         # sharding
-        if self.use_pipeline:
+        if self.use_pipeline and self.pp_allreduce_in_optimize:
             for idx in range(len(self._segments)):
                 assert len(self._segments[idx]._allreduce_vars) == 0
@@ -693,9 +701,15 @@ class ShardingOptimizer(MetaOptimizerBase):
             insert_sync_comm_ops(block, self._segments[-1]._end_idx,
                                  self.sharding_ring_id,
                                  self._segments[-1]._allreduce_vars)
-            insert_allreduce_ops(block, self._segments[-1]._end_idx,
-                                 self.sharding_ring_id,
-                                 self._segments[-1]._allreduce_vars)
+            # allreduce --> reduce
+            insert_reduce_ops(
+                block,
+                self._segments[-1]._end_idx,
+                self.sharding_ring_id,
+                self._segments[-1]._allreduce_vars,
+                self._shard,
+                op_role=OpRole.Backward,
+                use_calc_stream=False)
 
         for idx, segment in reversed(list(enumerate(self._segments))):
             allreduce_vars = self._segments[
...@@ -775,8 +789,15 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -775,8 +789,15 @@ class ShardingOptimizer(MetaOptimizerBase):
insert_sync_comm_ops(block, segment._start_idx, insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars) self.sharding_ring_id, allreduce_vars)
# sharding # sharding
insert_allreduce_ops(block, segment._start_idx, # allreduce --> reduce
self.sharding_ring_id, allreduce_vars) insert_reduce_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
block._sync_with_cpp() block._sync_with_cpp()
@@ -829,12 +850,6 @@ class ShardingOptimizer(MetaOptimizerBase):
     def _init_comm(self):
-        # sharding alone mode
-        # self.sharding_ring_id = 0
-        # self.sharding_rank = self.global_rank
-        # self.sharding_group_endpoints = self.endpoints[:]
-        # self.sharding_group_size = len(self.endpoints)
-
         if self.hybrid_dp:
             assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism"
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
@@ -854,7 +869,6 @@ class ShardingOptimizer(MetaOptimizerBase):
                 ep for idx, ep in enumerate(self.endpoints)
                 if (idx % self.sharding_group_size) == self.sharding_rank
             ]
-            # self.global_group_endpoints = self.role_maker._get_trainer_endpoints()[:]
             assert self.global_word_size > self.sharding_group_size, \
                 "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
...
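One way to sanity-check the pass after minimize() is to count the reduce ops it inserted into the transformed program. The op type and attribute names below are the ones used in this diff; the rest (including the main_program handle) is a hypothetical debugging snippet, not part of the commit.

# Hypothetical debugging snippet: inspect the transformed main program.
reduce_ops = [
    op for op in main_program.global_block().ops
    if op.type == "c_reduce_sum"
]
print("inserted {} c_reduce_sum ops".format(len(reduce_ops)))
for op in reduce_ops[:5]:
    print(op.type, op.attr("ring_id"), op.attr("root_id"))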
@@ -4838,7 +4838,7 @@ class PipelineOptimizer(object):
                     new_var.persistable = False
                     self._rename_arg(op, grad_name, new_grad_var_name)
 
-    def _accumulate_gradients(self, block):
+    def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False):
         """
         Accumulate the gradients generated in microbatch to the one in mini-batch.
         """
@@ -4875,7 +4875,11 @@ class PipelineOptimizer(object):
                 for i in range(0, len(op_role_var), 2):
                     offset = 0
                     param_name = op_role_var[i]
+                    # if not block.has_var(param_name): continue
+                    if not pp_allreduce_in_optimize:
+                        if not block.has_var(param_name):
+                            continue
                     if '@BroadCast' in param_name:
                         param_name = param_name[0:param_name.find('@BroadCast')]
                         # clear gradient
...