提交 e166873b 编写于 作者: sandyhouse

update

上级 a97b9df0
...@@ -4046,7 +4046,7 @@ class PipelineOptimizer(object): ...@@ -4046,7 +4046,7 @@ class PipelineOptimizer(object):
""" """
prev_op = [] prev_op = []
for op in ops: for op in ops:
if op.type == 'send_v2' or op.type == 'recv_v2': if op.type == 'send_v2' or op.type == 'recv_v2' or op.type == 'c_broadcast':
continue continue
if op == cur_op: if op == cur_op:
break break
...@@ -4434,6 +4434,39 @@ class PipelineOptimizer(object): ...@@ -4434,6 +4434,39 @@ class PipelineOptimizer(object):
ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
if pair not in self._pipeline_pair: if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair) self._pipeline_pair.append(pair)
if self.schedule_mode == 0: # GPipe
block._insert_op(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'ring_id': self.ring_id
if cur_device_index > prev_device_index else
self.ring_id + 2,
})
extra_index += 1
continue
assert self.schedule_mode == 1
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
#type='send_v2', #type='send_v2',
...@@ -4452,19 +4485,19 @@ class PipelineOptimizer(object): ...@@ -4452,19 +4485,19 @@ class PipelineOptimizer(object):
'root': 0, 'root': 0,
}) })
extra_index += 1 extra_index += 1
block._insert_op( #block._insert_op(
index=index + extra_index, # index=index + extra_index,
type='c_sync_comm_stream', # type='c_sync_comm_stream',
inputs={'X': [var]}, # inputs={'X': [var]},
outputs={'Out': [var]}, # outputs={'Out': [var]},
attrs={ # attrs={
self._op_device_key: cur_device, # self._op_device_key: cur_device,
self._op_role_key: # self._op_role_key:
core.op_proto_and_checker_maker.OpRole.Backward, # core.op_proto_and_checker_maker.OpRole.Backward,
'ring_id': self.ring_id, # 'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, # #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
}) # })
extra_index += 1 #extra_index += 1
fill_shape = list(var.shape) fill_shape = list(var.shape)
fill_shape[0] = 1 fill_shape[0] = 1
block._insert_op( block._insert_op(
...@@ -4509,6 +4542,7 @@ class PipelineOptimizer(object): ...@@ -4509,6 +4542,7 @@ class PipelineOptimizer(object):
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
self._op_device_key: cur_device, self._op_device_key: cur_device,
#self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
self._op_role_key: op_role, self._op_role_key: op_role,
'ring_id': self.ring_id, 'ring_id': self.ring_id,
#'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
...@@ -4987,6 +5021,10 @@ class PipelineOptimizer(object): ...@@ -4987,6 +5021,10 @@ class PipelineOptimizer(object):
and 'local_rank' in main_block.program._pipeline_opt, \ and 'local_rank' in main_block.program._pipeline_opt, \
'Please use pipeline with fleet.' 'Please use pipeline with fleet.'
local_rank = main_block.program._pipeline_opt['local_rank'] local_rank = main_block.program._pipeline_opt['local_rank']
schedule_mode = 0
if 'schedule_mode' in main_block.program._pipeline_opt:
schedule_mode = main_block.program._pipeline_opt['schedule_mode']
self.schedule_mode = schedule_mode
self.use_sharding = False self.use_sharding = False
if 'use_sharding' in main_block.program._pipeline_opt: if 'use_sharding' in main_block.program._pipeline_opt:
...@@ -5074,6 +5112,7 @@ class PipelineOptimizer(object): ...@@ -5074,6 +5112,7 @@ class PipelineOptimizer(object):
"inner_parallelism": len(device_list), "inner_parallelism": len(device_list),
"num_pipeline_stages": len(device_list), "num_pipeline_stages": len(device_list),
"pipeline_stage": local_rank, "pipeline_stage": local_rank,
"schedule_mode": schedule_mode,
"section_program": program_list[local_rank], "section_program": program_list[local_rank],
"place": place_list[local_rank], "place": place_list[local_rank],
"place_id": place_id, "place_id": place_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册