Commit c7472f16 authored by root, committed by sandyhouse

update

Parent eeca5ef6
@@ -38,6 +38,7 @@ message ShardingConfig {
   optional int32 acc_steps = 7 [ default = 1 ];
   optional int32 schedule_mode = 8 [ default = 0 ];
   optional int32 pp_bz = 9 [ default = 1 ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = true ];
 }
 
 message AMPConfig {
...
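Note: the new pp_allreduce_in_optimize field is consumed on the Python side through sharding_configs (see the optimizer changes below). A minimal sketch of how a user might pass it in, assuming the usual fleet.DistributedStrategy interface; the key names simply mirror the ShardingConfig proto fields above and the values are illustrative only.

import paddle.distributed.fleet as fleet

# Sketch: enable sharding and forward the new flag through sharding_configs.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "acc_steps": 1,
    "schedule_mode": 0,
    "pp_bz": 1,
    "pp_allreduce_in_optimize": True,  # field added by this commit
}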
@@ -88,7 +88,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     grad:
         - 0: op that generate Var
         - 1: sync_calc
-        - 2: allreduce_sum_sharding
+        - 2: reduce_sum_sharding (allreduce --> reduce)
         - 3: sync_comm
         - 4: allreuce_sum_dp (dp_grads)
         - 5: sync_comm (dp_grads)
@@ -103,7 +103,7 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
     idx_gradient_clip_allreduce = -1
     for idx, op in enumerate(block.ops):
-        if op.type == "c_allreduce_sum":
+        if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
             if op.all_attrs()["use_calc_stream"] == False:
                 ring_id = op.desc.attr("ring_id")
                 var_name = op.desc.input_arg_names()[0]
@@ -137,11 +137,12 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
                         var_name] == 0:
                     dp_grads_status[var_name] = 1
-        elif op.type == "c_allreduce_sum":
+        elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
             if op.all_attrs()["use_calc_stream"] == False:
                 var_name = op.desc.input_arg_names()[0]
                 ring_id = op.desc.attr("ring_id")
                 if ring_id == sharding_ring_id:
+                    assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce"
                     if var_name in vars_status:
                         _status = vars_status[var_name]
                     else:
@@ -191,6 +192,9 @@ def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1):
                         raise ValueError("There should be a sync_comm op "
                                          "after allreduce the Var: {}".format(
                                              input_name))
+                    raise ValueError(
+                        "The reduce output grad [{}] should NOT be be used in Non-root rank.".
+                        format(input_name))
                 if input_name in dp_grads_status:
                     if dp_ring_id == -1:
                         if dp_grads_status[input_name] != 3:
@@ -352,7 +356,9 @@ def get_grad_device(grad_name, shard):
         grad_name)
     base_name = None
     # mind the traversal order
-    possible_suffixes = ['.cast_fp16@GRAD', '@GRAD']
+    possible_suffixes = [
+        '.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD'
+    ]
     for suffix in possible_suffixes:
         if suffix in grad_name:
             base_name = re.sub(suffix, '', grad_name)
@@ -369,7 +375,7 @@ def insert_reduce_ops(block,
                       ring_id,
                       reduce_vars,
                       shard,
-                      op_role,
+                      op_role=OpRole.Backward,
                       use_calc_stream=False):
     """
     _add_allreduce_ops
@@ -389,10 +395,18 @@ def insert_reduce_ops(block,
             'use_calc_stream': use_calc_stream,
             OP_ROLE_KEY: op_role
         })
     return
+
+
+def get_first_check_finite_and_unscale_op_idx(block):
+    for idx, op in enumerate(block.ops):
+        if op.type == "check_finite_and_unscale":
+            return idx
+    raise ValueError("check_finite_and_unscale does not exist in block")
+
 
 def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
     """
     _add_broadcast_ops
...
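The widened suffix list in get_grad_device matters because, with gradient merging under pipeline, a merged gradient variable can carry an extra '_0' after '@GRAD'; the more specific suffixes must be tried first (hence the "mind the traversal order" comment). A small self-contained sketch of that suffix stripping, where the parameter-to-rank mapping is a hypothetical stand-in for the real shard object:

import re

def guess_grad_device(grad_name, param_to_rank):
    # Mirror the traversal order used above: most specific suffix first.
    possible_suffixes = [
        '.cast_fp16@GRAD_0', '.cast_fp16@GRAD', '@GRAD_0', '@GRAD'
    ]
    base_name = None
    for suffix in possible_suffixes:
        if suffix in grad_name:
            base_name = re.sub(suffix, '', grad_name)
            break
    assert base_name is not None, "invalid grad name: {}".format(grad_name)
    return param_to_rank[base_name]

param_to_rank = {"linear_0.w_0": 1}  # hypothetical shard mapping
print(guess_grad_device("linear_0.w_0.cast_fp16@GRAD_0", param_to_rank))  # -> 1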
@@ -100,6 +100,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         self.schedule_mode = self.user_defined_strategy.sharding_configs[
             "schedule_mode"]
         self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
+        self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
+            "pp_allreduce_in_optimize"]
 
         if self.inner_opt is None:
             raise ValueError(
@@ -179,6 +181,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             self._initialization_broadcast(startup_program)
 
         if self.use_pipeline:
+            # pp_optimizer._rename_gradient_var_name(main_block)
             # crop ops
             for idx, op in reversed(list(enumerate(main_block.ops))):
                 # if op.type == 'fill_constant' and int(op.attr('op_role')) == 16:
@@ -207,8 +210,10 @@ class ShardingOptimizer(MetaOptimizerBase):
             #         param_list.append(param_name)
             #pp_optimizer._clear_gradients(main_block, param_list)
             accumulated_grad_names = pp_optimizer._accumulate_gradients(
-                main_block)
+                main_block,
+                pp_allreduce_in_optimize=self.pp_allreduce_in_optimize)
             # accumulated_grad_names = sorted(accumulated_grad_names)
+            if self.pp_allreduce_in_optimize:
                 print("persistable FP32 grad: ")
                 print(accumulated_grad_names)
                 first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
@@ -540,7 +545,10 @@ class ShardingOptimizer(MetaOptimizerBase):
                         self._main_program.global_block().var(input_name))
 
             # find reduce vars
-            if not self.use_pipeline:
+            if self.use_pipeline and self.pp_allreduce_in_optimize:
+                # place pipeline gradient allreduce in optimize
+                pass
+            else:
                 if is_backward_op(op) and \
                         OP_ROLE_VAR_KEY in op.attr_names:
                     op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
@@ -678,7 +686,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         if len(self._segments) < 1:
             return
         # sharding
-        if self.use_pipeline:
+        if self.use_pipeline and self.pp_allreduce_in_optimize:
             for idx in range(len(self._segments)):
                 assert len(self._segments[idx]._allreduce_vars) == 0
@@ -693,9 +701,15 @@ class ShardingOptimizer(MetaOptimizerBase):
             insert_sync_comm_ops(block, self._segments[-1]._end_idx,
                                  self.sharding_ring_id,
                                  self._segments[-1]._allreduce_vars)
-            insert_allreduce_ops(block, self._segments[-1]._end_idx,
-                                 self.sharding_ring_id,
-                                 self._segments[-1]._allreduce_vars)
+            # allreduce --> reduce
+            insert_reduce_ops(
+                block,
+                self._segments[-1]._end_idx,
+                self.sharding_ring_id,
+                self._segments[-1]._allreduce_vars,
+                self._shard,
+                op_role=OpRole.Backward,
+                use_calc_stream=False)
 
         for idx, segment in reversed(list(enumerate(self._segments))):
             allreduce_vars = self._segments[
@@ -775,8 +789,15 @@ class ShardingOptimizer(MetaOptimizerBase):
             insert_sync_comm_ops(block, segment._start_idx,
                                  self.sharding_ring_id, allreduce_vars)
             # sharding
-            insert_allreduce_ops(block, segment._start_idx,
-                                 self.sharding_ring_id, allreduce_vars)
+            # allreduce --> reduce
+            insert_reduce_ops(
+                block,
+                segment._start_idx,
+                self.sharding_ring_id,
+                allreduce_vars,
+                self._shard,
+                op_role=OpRole.Backward,
+                use_calc_stream=False)
 
         block._sync_with_cpp()
@@ -829,12 +850,6 @@ class ShardingOptimizer(MetaOptimizerBase):
     def _init_comm(self):
-        # sharding alone mode
-        # self.sharding_ring_id = 0
-        # self.sharding_rank = self.global_rank
-        # self.sharding_group_endpoints = self.endpoints[:]
-        # self.sharding_group_size = len(self.endpoints)
 
         if self.hybrid_dp:
             assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism"
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
@@ -854,7 +869,6 @@ class ShardingOptimizer(MetaOptimizerBase):
                 ep for idx, ep in enumerate(self.endpoints)
                 if (idx % self.sharding_group_size) == self.sharding_rank
             ]
-            # self.global_group_endpoints = self.role_maker._get_trainer_endpoints()[:]
             assert self.global_word_size > self.sharding_group_size, \
                 "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
...
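Background for the allreduce --> reduce switch above: under sharding each rank only updates the slice of parameters it owns, so a summed gradient only has to arrive at the owning (root) rank; the new assert and the "should NOT be used in Non-root rank" check in check_allreduce_sum guard exactly that invariant. A toy, Paddle-free illustration of what each rank ends up holding under the two collectives:

# Toy model: one gradient value per rank.
def allreduce_sum(local_grads):
    total = sum(local_grads)
    return [total] * len(local_grads)         # every rank keeps the full sum

def reduce_sum(local_grads, root):
    total = sum(local_grads)
    return [total if r == root else None      # only the owning rank keeps it
            for r in range(len(local_grads))]

local_grads = [1, 2, 3, 4]                    # illustrative values, 4 ranks
print(allreduce_sum(local_grads))             # [10, 10, 10, 10]
print(reduce_sum(local_grads, root=2))        # [None, None, 10, None]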
@@ -4838,7 +4838,7 @@ class PipelineOptimizer(object):
             new_var.persistable = False
             self._rename_arg(op, grad_name, new_grad_var_name)
 
-    def _accumulate_gradients(self, block):
+    def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False):
         """
         Accumulate the gradients generated in microbatch to the one in mini-batch.
         """
@@ -4875,7 +4875,11 @@ class PipelineOptimizer(object):
             for i in range(0, len(op_role_var), 2):
                 offset = 0
                 param_name = op_role_var[i]
-                # if not block.has_var(param_name): continue
+                if not pp_allreduce_in_optimize:
+                    if not block.has_var(param_name):
+                        continue
+
                 if '@BroadCast' in param_name:
                     param_name = param_name[0:param_name.find('@BroadCast')]
                 # clear gradient
...
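The _accumulate_gradients docstring above describes summing per-microbatch gradients into the mini-batch gradient; in the hunk shown here, the new pp_allreduce_in_optimize argument only controls whether gradients for parameters absent from the current block are skipped. A minimal numeric sketch of the accumulation semantics (illustrative only; the real pass rewrites program ops rather than computing values):

# Conceptual only: 3 microbatch gradients accumulated into one mini-batch gradient.
microbatch_grads = [[0.5, -1.0], [0.25, 0.75], [0.25, 0.25]]

minibatch_grad = [0.0, 0.0]
for grad in microbatch_grads:
    minibatch_grad = [acc + g for acc, g in zip(minibatch_grad, grad)]

print(minibatch_grad)  # [1.0, 0.0]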