diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 4aae43ccdbe3d6d58fcdcf435309a41dcde8bcd0..9a5511441481513f5c96729227c55c95953a297e 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4663,6 +4663,7 @@ class PipelineOptimizer(object): pre_stage_id = None decrease_flag = False in_optimize = False + in_forward = True for op in block.ops: if not op._has_kernel(op.type): assert op.type == "conditional_block" and ( @@ -4680,6 +4681,8 @@ class PipelineOptimizer(object): valid_op_role_value) if int(op_role) == int(self._op_role.Optimize): in_optimize = True + if int(op_role) == int(self._op_role.Backward): + in_forward = False assert op.has_attr(self._op_device_key), ( "op ({}) has no {} attribute.".format(op.type, @@ -4707,14 +4710,16 @@ class PipelineOptimizer(object): "but the interval of op={} and prev op is {}".format(op, interval) # stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0) # if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error - if interval == -1: - decrease_flag = True - if interval == 1: - # FIXME(wangxi): recompute failed + if in_forward: + assert interval >= 0, \ + "Pipeline stage must be sequential increment in Forward, prev_stage={}, " \ + "please check the stage of op={}".format(pre_stage_id, op) + else: + # FIXME(wangxi): recompute check failed pass - #assert decrease_flag is False, \ - # "Pipeline stage must be in order, " \ - # "please check the stage of op={}".format(op) + #assert interval <=0, \ + # "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \ + # "please check the stage of op={}".format(pre_stage_id, op) pre_stage_id = stage_id return device_list