未验证 提交 5ce58d57 编写于 作者: W WangXi 提交者: GitHub

[hybrid check] improve pipeline stage check (#34193)

上级 4e5cb7d8
......@@ -4663,6 +4663,7 @@ class PipelineOptimizer(object):
pre_stage_id = None
decrease_flag = False
in_optimize = False
in_forward = True
for op in block.ops:
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
......@@ -4680,6 +4681,8 @@ class PipelineOptimizer(object):
valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True
if int(op_role) == int(self._op_role.Backward):
in_forward = False
assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
......@@ -4707,14 +4710,16 @@ class PipelineOptimizer(object):
"but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if interval == -1:
decrease_flag = True
if interval == 1:
# FIXME(wangxi): recompute failed
if in_forward:
assert interval >= 0, \
"Pipeline stage must be sequential increment in Forward, prev_stage={}, " \
"please check the stage of op={}".format(pre_stage_id, op)
else:
# FIXME(wangxi): recompute check failed
pass
#assert decrease_flag is False, \
# "Pipeline stage must be in order, " \
# "please check the stage of op={}".format(op)
#assert interval <=0, \
# "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \
# "please check the stage of op={}".format(pre_stage_id, op)
pre_stage_id = stage_id
return device_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册