提交 a6879219 编写于 作者: W WangXi 提交者: sandyhouse

pipeline sequential

上级 e79f24f8
......@@ -4645,143 +4645,128 @@ class PipelineOptimizer(object):
if var_name not in input_var_to_device:
input_var_to_device[var_name] = []
if cur_device in input_var_to_device[var_name]:
if (cur_device, prev_device) in input_var_to_device[var_name]:
continue
input_var_to_device[var_name].append(cur_device)
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if cur_device_index > prev_device_index:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
if self.schedule_mode == 0: # GPipe
device_type = cur_device.split(':')[0] + ':'
def _insert_send_recv(cur_id, prev_id):
nonlocal extra_index
cur_dev = device_type + str(cur_id)
prev_dev = device_type + str(prev_id)
if (cur_dev, prev_dev) in input_var_to_device[var_name]:
return
if cur_id - prev_id > 1:
_insert_send_recv(cur_id - 1, prev_id)
_insert_send_recv(cur_id, cur_id - 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
elif cur_id - prev_id < -1:
_insert_send_recv(cur_id + 1, prev_id)
_insert_send_recv(cur_id, cur_id + 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
assert abs(cur_id - prev_id) == 1
input_var_to_device[var_name].append((cur_dev, prev_dev))
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
pair = (prev_id, cur_id)
pair_key = prev_id * 1000 + cur_id
if cur_id > prev_id:
ring_id = self.ring_id + cur_id - prev_id - 1
else:
ring_id = self.ring_id + 2 + prev_id - cur_id - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_calc_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: op_role,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='send_v2',
type="c_broadcast",
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_device,
self._op_device_key: prev_dev,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
'use_calc_stream': False,
'ring_id': ring_id,
'root': 0,
})
extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 4
block._insert_op_without_sync(
index=index + extra_index,
type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_device_key: cur_dev,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: cur_dev,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
'root': 0,
'ring_id': ring_id,
})
extra_index += 1
continue
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync(
index=index + extra_index,
#type='send_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
# block._insert_op_without_sync(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_dev,
# self._op_role_key: op_role,
# 'ring_id': ring_id,
# })
# extra_index += 1
_insert_send_recv(
int(cur_device.split(':')[1]),
int(prev_device.split(':')[1]))
block._sync_with_cpp()
def _clear_gradients(self, main_block, param_names):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册