提交 af17a6ee 编写于 作者: W WangXi 提交者: sandyhouse

Optimize the pipeline-parallel (pp) optimizer's send/recv insertion from O(n^2) to O(n)

上级 5646f710
......@@ -4515,6 +4515,96 @@ class PipelineOptimizer(object):
})
extra_index += 1
def _xx_insert_sendrecv_ops_for_boundaries(self, block):
    """
    Insert a send_v2/recv_v2 pair wherever an op reads a var whose most
    recent writer is placed on a different device.

    The ops are walked once, in program order, while remembering the
    latest op that wrote each var. The producer of every input is then
    found with a single dict lookup instead of a scan over earlier ops.

    Args:
        block: the block to edit in place. New ops are added with
            ``_insert_op_without_sync`` and the block is synced exactly
            once, after the walk, with ``_sync_with_cpp``.

    Notes:
        * Ops whose device attr is "gpu:all" never take part in a
          transfer, neither as consumer nor as producer.
        * Inputs flagged ``is_data`` are skipped here.
        * A var is transferred to a given consumer device at most once.
        * Device attrs are parsed as "<kind>:<index>"; the index is used
          as the ``peer`` attr of the inserted ops.
    """
    # Number of ops inserted so far; positions taken from the snapshot
    # below are shifted by this amount in the live block.
    extra_index = 0
    # A map from var to the devices that already receive it,
    # avoiding multiple send and recv ops.
    input_var_to_device = dict()
    # A map from var to the latest op (in program order) that outputs it.
    last_writer = dict()

    # Iterate over a snapshot: block.ops grows while ops are inserted.
    for index, op in enumerate(list(block.ops)):
        cur_device = op.attr(self._op_device_key)
        # "gpu:all" ops never receive anything, but their outputs must
        # still be recorded below, so only the input scan is skipped.
        input_names = () if cur_device == "gpu:all" else op.input_arg_names
        for var_name in input_names:
            var = block.var(var_name)
            if var.is_data:
                continue
            prev_op = last_writer.get(var_name)
            if prev_op is None:
                # No earlier op in this block writes the var.
                continue
            prev_device = prev_op.attr(self._op_device_key)
            if prev_device is None or prev_device == 'gpu:all':
                continue
            if prev_device == cur_device:
                continue
            receivers = input_var_to_device.setdefault(var_name, [])
            if cur_device in receivers:
                continue
            receivers.append(cur_device)

            # The inserted pair inherits the role of the consuming op.
            op_role = op.all_attrs()[self._op_role_key]
            prev_device_index = int(prev_device.split(':')[1])
            cur_device_index = int(cur_device.split(':')[1])
            # Both ops go right in front of the consuming op:
            # send first, then the matching recv.
            block._insert_op_without_sync(
                index=index + extra_index,
                type='send_v2',
                inputs={'X': var},
                attrs={
                    self._op_device_key: prev_device,
                    self._op_role_key: op_role,
                    'use_calc_stream': True,
                    'peer': cur_device_index,
                    'ring_id': self.ring_id,
                })
            extra_index += 1
            block._insert_op_without_sync(
                index=index + extra_index,
                type='recv_v2',
                outputs={'Out': [var]},
                attrs={
                    'out_shape': var.shape,
                    'dtype': var.dtype,
                    self._op_device_key: cur_device,
                    self._op_role_key: op_role,
                    'use_calc_stream': True,
                    'peer': prev_device_index,
                    'ring_id': self.ring_id,
                })
            extra_index += 1

        # Record outputs only after the inputs were handled, so an op
        # that reads and writes the same var is never its own producer.
        for var_name in op.output_arg_names:
            last_writer[var_name] = op

    block._sync_with_cpp()
def _clear_gradients(self, main_block, param_names):
"""
Clear gradients at the beginning of each run of a minibatch.
......@@ -4932,7 +5022,7 @@ class PipelineOptimizer(object):
"another in the order of their ids.")
# Step2: add send and recv ops between section boundaries
self._insert_sendrecv_ops_for_boundaries(main_block)
self._xx_insert_sendrecv_ops_for_boundaries(main_block)
# Step3: split program into sections and add pairs of
# send and recv ops for data var.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册