未验证 提交 5ce58d57 编写于 作者: W WangXi 提交者: GitHub

[hybrid check] improve pipeline stage check (#34193)

上级 4e5cb7d8
...@@ -4663,6 +4663,7 @@ class PipelineOptimizer(object): ...@@ -4663,6 +4663,7 @@ class PipelineOptimizer(object):
pre_stage_id = None pre_stage_id = None
decrease_flag = False decrease_flag = False
in_optimize = False in_optimize = False
in_forward = True
for op in block.ops: for op in block.ops:
if not op._has_kernel(op.type): if not op._has_kernel(op.type):
assert op.type == "conditional_block" and ( assert op.type == "conditional_block" and (
...@@ -4680,6 +4681,8 @@ class PipelineOptimizer(object): ...@@ -4680,6 +4681,8 @@ class PipelineOptimizer(object):
valid_op_role_value) valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize): if int(op_role) == int(self._op_role.Optimize):
in_optimize = True in_optimize = True
if int(op_role) == int(self._op_role.Backward):
in_forward = False
assert op.has_attr(self._op_device_key), ( assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type, "op ({}) has no {} attribute.".format(op.type,
...@@ -4707,14 +4710,16 @@ class PipelineOptimizer(object): ...@@ -4707,14 +4710,16 @@ class PipelineOptimizer(object):
"but the interval of op={} and prev op is {}".format(op, interval) "but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0) # stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error # if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if interval == -1: if in_forward:
decrease_flag = True assert interval >= 0, \
if interval == 1: "Pipeline stage must be sequential increment in Forward, prev_stage={}, " \
# FIXME(wangxi): recompute failed "please check the stage of op={}".format(pre_stage_id, op)
else:
# FIXME(wangxi): recompute check failed
pass pass
#assert decrease_flag is False, \ #assert interval <=0, \
# "Pipeline stage must be in order, " \ # "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \
# "please check the stage of op={}".format(op) # "please check the stage of op={}".format(pre_stage_id, op)
pre_stage_id = stage_id pre_stage_id = stage_id
return device_list return device_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册