提交 a6879219 编写于 作者: W WangXi 提交者: sandyhouse

pipeline sequential

上级 e79f24f8
...@@ -4645,143 +4645,128 @@ class PipelineOptimizer(object): ...@@ -4645,143 +4645,128 @@ class PipelineOptimizer(object):
if var_name not in input_var_to_device: if var_name not in input_var_to_device:
input_var_to_device[var_name] = [] input_var_to_device[var_name] = []
if cur_device in input_var_to_device[var_name]: if (cur_device, prev_device) in input_var_to_device[var_name]:
continue continue
input_var_to_device[var_name].append(cur_device)
device_type = cur_device.split(':')[0] + ':'
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name] def _insert_send_recv(cur_id, prev_id):
prev_device_index = int(prev_device.split(':')[1]) nonlocal extra_index
cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index) cur_dev = device_type + str(cur_id)
pair_key = prev_device_index * 1000 + cur_device_index prev_dev = device_type + str(prev_id)
if cur_device_index > prev_device_index: if (cur_dev, prev_dev) in input_var_to_device[var_name]:
ring_id = self.ring_id + cur_device_index - prev_device_index - 1 return
else:
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 if cur_id - prev_id > 1:
print("call xx_insert, schedule_mode:", self.schedule_mode) _insert_send_recv(cur_id - 1, prev_id)
if self.schedule_mode == 0: # GPipe _insert_send_recv(cur_id, cur_id - 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
elif cur_id - prev_id < -1:
_insert_send_recv(cur_id + 1, prev_id)
_insert_send_recv(cur_id, cur_id + 1)
input_var_to_device[var_name].append(
(cur_dev, prev_dev))
return
assert abs(cur_id - prev_id) == 1
input_var_to_device[var_name].append((cur_dev, prev_dev))
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
pair = (prev_id, cur_id)
pair_key = prev_id * 1000 + cur_id
if cur_id > prev_id:
ring_id = self.ring_id + cur_id - prev_id - 1
else:
ring_id = self.ring_id + 2 + prev_id - cur_id - 1
print("call xx_insert, schedule_mode:", self.schedule_mode)
assert self.schedule_mode == 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id))
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_calc_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: op_role,
})
extra_index += 1
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index, index=index + extra_index,
type='send_v2', type="c_broadcast",
inputs={'X': var}, inputs={'X': var},
outputs={'Out': var},
attrs={ attrs={
self._op_device_key: prev_device, self._op_device_key: prev_dev,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'use_calc_stream': False,
'peer': cur_device_index, 'ring_id': ring_id,
'ring_id': self.ring_id 'root': 0,
if cur_device_index > prev_device_index else
self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 4
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index, index=index + extra_index,
type='recv_v2', type='fill_constant',
inputs={},
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
'out_shape': var.shape, 'shape': fill_shape,
'dtype': var.dtype, 'dtype': var.dtype,
self._op_device_key: cur_device, self._op_device_key: cur_dev,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
self._op_device_key: cur_dev,
self._op_role_key: op_role, self._op_role_key: op_role,
'use_calc_stream': True, 'use_calc_stream': True,
'peer': prev_device_index, 'root': 0,
'ring_id': self.ring_id 'ring_id': ring_id,
if cur_device_index > prev_device_index else
self.ring_id + 2,
}) })
extra_index += 1 extra_index += 1
continue
assert self.schedule_mode == 1 # block._insert_op_without_sync(
if pair not in self._pipeline_pair: # index=index + extra_index,
self._pipeline_pair.append(pair) # type='c_sync_comm_stream',
self._pp_ring_map[pair_key] = self.ring_id # inputs={'X': [var]},
ring_id = self.ring_id # outputs={'Out': [var]},
self.ring_id += 1 # attrs={
else: # self._op_device_key: cur_dev,
ring_id = self._pp_ring_map[pair_key] # self._op_role_key: op_role,
print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id)) # 'ring_id': ring_id,
block._insert_op_without_sync( # })
index=index + extra_index, # extra_index += 1
#type='send_v2',
type='c_broadcast', _insert_send_recv(
inputs={'X': var}, int(cur_device.split(':')[1]),
outputs={'Out': var}, int(prev_device.split(':')[1]))
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': cur_device_index,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
'ring_id': ring_id,
#'ring_id': self.ring_id,
#'root': prev_device_index,
'root': 0,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='fill_constant',
inputs={},
outputs={'Out': [var]},
attrs={
'shape': fill_shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'value': float(0.0),
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
#type='recv_v2',
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
#'out_shape': var.shape,
#'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': False,
#'peer': prev_device_index,
#'root': prev_device_index,
'root': 0,
#'ring_id': self.ring_id,
'ring_id': ring_id,
#'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2,
#'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._insert_op_without_sync(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
block._sync_with_cpp() block._sync_with_cpp()
def _clear_gradients(self, main_block, param_names): def _clear_gradients(self, main_block, param_names):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册