提交 e79f24f8 编写于 作者: S sandyhouse

add sync for c_broadcast

上级 779fde8d
......@@ -37,6 +37,7 @@ message ShardingConfig {
optional bool use_pipeline = 6 [ default = false ];
optional int32 acc_steps = 7 [ default = 1 ];
optional int32 schedule_mode = 8 [ default = 0 ];
optional int32 pp_bz = 9 [ default = 1 ];
}
message AMPConfig {
......
......@@ -98,6 +98,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"acc_steps"]
self.schedule_mode = self.user_defined_strategy.sharding_configs[
"schedule_mode"]
self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"]
if self.inner_opt is None:
raise ValueError(
......@@ -108,6 +109,7 @@ class ShardingOptimizer(MetaOptimizerBase):
main_program = loss.block.program
main_program._pipeline_opt = dict()
main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
main_program._pipeline_opt['pp_bz'] = self.pp_bz
pp_rank = self.role_maker._worker_index() // (
self.user_defined_strategy.sharding_configs[
'sharding_group_size'] * self._inner_parallelism_size)
......
......@@ -4416,7 +4416,11 @@ class PipelineOptimizer(object):
var = block.var(var_name)
# skip data, because we will process it later
if var.is_data: continue
prev_device = None
if var_name in self._param_device_map:
prev_device = self._param_device_map[var_name]
prev_op = self._find_real_prev_op(block.ops, op, var_name)
if not pre_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
if not prev_device or prev_device == 'gpu:all': continue
......@@ -4494,6 +4498,20 @@ class PipelineOptimizer(object):
'root': 0,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward,
#self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
......@@ -4508,7 +4526,7 @@ class PipelineOptimizer(object):
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
fill_shape[0] = self.pp_bz
block._insert_op(
index=index + extra_index,
#type='recv_v2',
......@@ -4523,6 +4541,19 @@ class PipelineOptimizer(object):
'value': float(0.0),
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
#type='recv_v2',
......@@ -4591,8 +4622,12 @@ class PipelineOptimizer(object):
# continue
#input_var_to_device[var_name].append(cur_device)
prev_device = None
generate_ops = output_var_to_op.get(var_name)
if generate_ops is None: continue
if generate_ops is None:
if var_name not in self._param_device_map:
continue
prev_device = self._param_device_map[var_name]
prev_op = None
for gen_op, gen_idx in reversed(generate_ops):
......@@ -4600,6 +4635,7 @@ class PipelineOptimizer(object):
prev_op = gen_op
break
if not prev_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
......@@ -5134,6 +5170,7 @@ class PipelineOptimizer(object):
if 'schedule_mode' in main_block.program._pipeline_opt:
schedule_mode = main_block.program._pipeline_opt['schedule_mode']
self.schedule_mode = schedule_mode
self.pp_bz = main_block.program._pipeline_opt['pp_bz']
self.use_sharding = False
if 'use_sharding' in main_block.program._pipeline_opt:
......@@ -5175,15 +5212,117 @@ class PipelineOptimizer(object):
# send and recv ops for data var.
main_program = main_block.program
program_list = self._split_program(main_program, device_list)
#cur_device_index = 0
#device_num = len(program_list)
for p in program_list:
self._create_vars(p["program"].block(0), main_block)
# # Add send/recv pair to sync the execution.
# block = p['program'].block(0)
# prev_device_index = cur_device_index - 1
# next_device_index = cur_device_index + 1
# add_send_for_forward = False
# add_send_for_backward = False
# add_recv_for_backward = False
# extra_index = 0
# new_var = block.create_var(
# name=unique_name.generate('sync'),
# shape=[1],
# dtype='float32',
# persistable=False,
# stop_gradient=True)
# block._insert_op(
# index=0,
# type='fill_constant',
# inputs={},
# outputs={'Out': [new_var]},
# attrs={
# 'shape': [1],
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'value': float(0.0),
# })
# extra_index += 1
# for op_idx, op in enumerate(list(block.ops)):
# if op_idx == extra_index:
# if cur_device_index > 0:
# pair_key = prev_device_index * 1000 + cur_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# continue
# if op.type == "send_v2" and self._is_forward_op(op) \
# and not add_send_for_forward \
# and cur_device_index < device_num - 1:
# add_send_for_forward = True
# pair_key = cur_device_index * 1000 + next_device_index
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# inputs={'Out': new_var},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Forward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# extra_index += 1
# if self._is_backward_op(op) and not add_recv_for_backward \
# and cur_device_index < device_num - 1:
# pair_key = next_device_index * 1000 + cur_device_index
# add_recv_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='recv_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 0,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# if op.type == "send_v2" and self._is_backward_op(op) \
# and not add_send_for_backward \
# and cur_device_index > 0:
# pair_key = cur_device_index * 1000 + prev_device_index
# add_send_for_backward = True
# ring_id = self._pp_ring_map[pair_key]
# block._insert_op(
# index=op_idx + extra_index,
# type='send_v2',
# outputs={'Out': [new_var]},
# attrs={
# 'out_shape': new_var.shape,
# 'dtype': new_var.dtype,
# self._op_role_key: self._op_role.Backward,
# 'peer': 1,
# 'use_calc_stream': True,
# 'ring_id': ring_id,
# })
# cur_device_index += 1
#self._insert_sendrecv_for_data_var(main_block, program_list,
# startup_program, device_list)
# Step4: Special Case: process persistable vars that exist in
# multiple sections
self._process_persistable_vars_in_multi_sections(
main_program, startup_program, program_list)
#self._process_persistable_vars_in_multi_sections(
# main_program, startup_program, program_list)
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册