提交 e166873b 编写于 作者: S sandyhouse

update

上级 a97b9df0
......@@ -4046,7 +4046,7 @@ class PipelineOptimizer(object):
"""
prev_op = []
for op in ops:
if op.type == 'send_v2' or op.type == 'recv_v2':
if op.type == 'send_v2' or op.type == 'recv_v2' or op.type == 'c_broadcast':
continue
if op == cur_op:
break
......@@ -4434,6 +4434,39 @@ class PipelineOptimizer(object):
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
if self.schedule_mode == 0: # GPipe
block._insert_op(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
continue
assert self.schedule_mode == 1
block._insert_op(
index=index + extra_index,
#type='send_v2',
......@@ -4452,19 +4485,19 @@ class PipelineOptimizer(object):
'root': 0,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward,
'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
})
extra_index += 1
#block._insert_op(
# index=index + extra_index,
# type='c_sync_comm_stream',
# inputs={'X': [var]},
# outputs={'Out': [var]},
# attrs={
# self._op_device_key: cur_device,
# self._op_role_key:
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'ring_id': self.ring_id,
# #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
# })
#extra_index += 1
fill_shape = list(var.shape)
fill_shape[0] = 1
block._insert_op(
......@@ -4509,6 +4542,7 @@ class PipelineOptimizer(object):
outputs={'Out': [var]},
attrs={
self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role,
'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
......@@ -4987,6 +5021,10 @@ class PipelineOptimizer(object):
and 'local_rank' in main_block.program._pipeline_opt, \
'Please use pipeline with fleet.'
local_rank = main_block.program._pipeline_opt['local_rank']
schedule_mode = 0
if 'schedule_mode' in main_block.program._pipeline_opt:
schedule_mode = main_block.program._pipeline_opt['schedule_mode']
self.schedule_mode = schedule_mode
self.use_sharding = False
if 'use_sharding' in main_block.program._pipeline_opt:
......@@ -5074,6 +5112,7 @@ class PipelineOptimizer(object):
"inner_parallelism": len(device_list),
"num_pipeline_stages": len(device_list),
"pipeline_stage": local_rank,
"schedule_mode": schedule_mode,
"section_program": program_list[local_rank],
"place": place_list[local_rank],
"place_id": place_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册