未验证 提交 a2dbb0c2 编写于 作者: W WangXi 提交者: GitHub

[hybrid] check pipeline persist var which changed in forward and used in backward (#35453)

上级 b95c5ae0
......@@ -5715,6 +5715,35 @@ class PipelineOptimizer(object):
backward_insert_index += 1
block._sync_with_cpp()
def _check_pipeline_persist_var(self, program):
"""
Pipeline may need multiple forward before
"""
block = program.global_block()
persist_output = set()
used_in_backward = set()
for op in block.ops:
if self._is_forward_op(op):
for var_name in op.output_arg_names:
var = block.vars[var_name]
if var.persistable:
persist_output.add(var_name)
elif self._is_backward_op(op):
for var_name in op.input_arg_names:
if var_name in persist_output:
used_in_backward.add(var_name)
if len(used_in_backward) == 0:
return
warnings.warn(
"The pipeline requires multiple forward calculations before backward, "
"so when the persistable var is changed in the forward, it may cause "
"errors in the backward calculation who using this persistable var. "
"However, some backward op don't need this var(NoNeedBufferVars), "
"there will be no error at this time.\n"
"So please check these persistable vars which changed in "
"forward and used in backward:\n{}".format(used_in_backward))
def minimize(self,
loss,
startup_program=None,
......@@ -5831,6 +5860,11 @@ class PipelineOptimizer(object):
# A pass to move the recv op to the beginning of
# the forward/backward phase
self._mv_head_recv(program_list[self.local_rank])
# A pass to check pipeline persist var which changed in
# forward and used in backward
self._check_pipeline_persist_var(program_list[self.local_rank])
main_program._pipeline_opt = {
"trainer": "PipelineTrainer",
"device_worker": "Section",
......
......@@ -45,6 +45,8 @@ class TestFleetMetaOptimizer(unittest.TestCase):
with static.device_guard("gpu:1"):
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
# for pipeline check_pipeline_persist_var coverage
fc_2.persistable = True
fc_2 = fc_2 * input_z
prediction = paddle.fluid.layers.fc(input=[fc_2],
size=2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册