From 5ce58d570a05ee74fd664eba5df711d122b1b15e Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 16 Jul 2021 10:41:47 +0800 Subject: [PATCH] [hybrid check] improve pipeline stage check (#34193) --- python/paddle/fluid/optimizer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 4aae43ccdbe..9a551144148 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4663,6 +4663,7 @@ class PipelineOptimizer(object): pre_stage_id = None decrease_flag = False in_optimize = False + in_forward = True for op in block.ops: if not op._has_kernel(op.type): assert op.type == "conditional_block" and ( @@ -4680,6 +4681,8 @@ class PipelineOptimizer(object): valid_op_role_value) if int(op_role) == int(self._op_role.Optimize): in_optimize = True + if int(op_role) == int(self._op_role.Backward): + in_forward = False assert op.has_attr(self._op_device_key), ( "op ({}) has no {} attribute.".format(op.type, @@ -4707,14 +4710,16 @@ class PipelineOptimizer(object): "but the interval of op={} and prev op is {}".format(op, interval) # stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0) # if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error - if interval == -1: - decrease_flag = True - if interval == 1: - # FIXME(wangxi): recompute failed + if in_forward: + assert interval >= 0, \ + "Pipeline stage must be sequential increment in Forward, prev_stage={}, " \ + "please check the stage of op={}".format(pre_stage_id, op) + else: + # FIXME(wangxi): recompute check failed pass - #assert decrease_flag is False, \ - # "Pipeline stage must be in order, " \ - # "please check the stage of op={}".format(op) + #assert interval <=0, \ + # "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \ + # "please check the stage of op={}".format(pre_stage_id, op) pre_stage_id = stage_id return device_list -- GitLab