未验证 提交 cc007dce 编写于 作者: L lilong12 提交者: GitHub

move the recv op the beginning of the forward/backward phase for pipeline (#34197)

* mv recv to head, test=develop
上级 9c7f6af5
......@@ -5280,6 +5280,55 @@ class PipelineOptimizer(object):
attrs={self._op_role_key: self._op_role.Backward})
block._sync_with_cpp()
def _mv_head_recv(self, program):
"""
A pass to move the recv op to the beginning of
the forward/backward phase
"""
forward_insert_index = 0
backward_insert_index = None
block = program.global_block()
num_ops = len(program.global_block().ops)
for i in range(num_ops):
insert_index = None
op = program.global_block().ops[i]
op_role = int(op.attr(self._op_role_key))
if op_role == int(
self._op_role.Backward) and backward_insert_index is None:
backward_insert_index = i
if op.type != "partial_recv" and op.type != "partial_allgather" and op.type != "nop" and op.type != "recv_v2":
continue
if op_role == int(self._op_role.Forward):
if i == forward_insert_index:
forward_insert_index += 1
continue
insert_index = forward_insert_index
elif op_role == int(self._op_role.Backward):
if i == backward_insert_index:
backward_insert_index += 1
continue
insert_index = backward_insert_index
else:
raise ValueError("Unknown op_role: {}".format(op_role))
op_inputs = dict()
for name in op.input_names:
op_inputs[name] = op.input(name)
op_outputs = dict()
for name in op.output_names:
op_outputs[name] = op.output(name)
block._insert_op_without_sync(
index=insert_index,
type=op.type,
inputs=op_inputs,
outputs=op_outputs,
attrs=op.all_attrs())
block._remove_op(i + 1)
if op_role == int(self._op_role.Forward):
forward_insert_index += 1
elif op_role == int(self._op_role.Backward):
backward_insert_index += 1
block._sync_with_cpp()
def minimize(self,
loss,
startup_program=None,
......@@ -5393,6 +5442,9 @@ class PipelineOptimizer(object):
place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
elif core.is_compiled_with_npu():
place_id = int(os.getenv("FLAGS_selected_npus", "0"))
# A pass to move the recv op to the beginning of
# the forward/backward phase
self._mv_head_recv(program_list[self.local_rank])
main_program._pipeline_opt = {
"trainer": "PipelineTrainer",
"device_worker": "Section",
......
......@@ -144,6 +144,7 @@ class TestDistRunnerBase(object):
loss = loss[0] if loss else None
out_losses.append(loss)
print_to_err(type(self).__name__, "run step %d finished" % i)
data_loader.reset()
print_to_err(type(self).__name__, "trainer run finished")
sys.stdout.buffer.write(pickle.dumps(out_losses))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册