未验证 提交 7a633e64 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【new ir】modify test comp divide_grad (#56697)

* modify test comp grad

* modify test comp grad
上级 ad93dc0c
...@@ -179,51 +179,41 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): ...@@ -179,51 +179,41 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
pruned op in total_ops is uneffective_ops, else is effective_ops pruned op in total_ops is uneffective_ops, else is effective_ops
''' '''
relevant_op_flags = [True] * len(total_ops) intersection_op_flags = [True] * len(total_ops)
union_op_flags = [False] * len(total_ops)
# from input to output # from input to output
if inputs_set: if inputs_set:
for i, op in enumerate(total_ops): for i, op in enumerate(total_ops):
if some_in_set(op.results(), inputs_set): if some_in_set(op.results(), inputs_set):
union_op_flags[i] = True
continue continue
if some_in_set(op.operands_source(), inputs_set): if some_in_set(op.operands_source(), inputs_set):
union_op_flags[i] = True
for value in op.results(): for value in op.results():
if value not in no_grad_set: if value not in no_grad_set:
inputs_set.add(value) inputs_set.add(value)
else: else:
relevant_op_flags[i] = False intersection_op_flags[i] = False
# from output to input # from output to input
for i, op in reversed(list(enumerate(total_ops))): for i, op in reversed(list(enumerate(total_ops))):
if some_in_set(op.results(), outputs_set): if some_in_set(op.results(), outputs_set):
union_op_flags[i] = True
for operand in op.operands_source(): for operand in op.operands_source():
if operand not in no_grad_set: if operand not in no_grad_set:
outputs_set.add(operand) outputs_set.add(operand)
else: else:
relevant_op_flags[i] = False union_op_flags[i] = False
# recover full op or full_Intarray op created by mutable attribute. intersection_op_flags[i] = False
total_ops_list = list(total_ops)
for i, op in enumerate(total_ops_list):
if relevant_op_flags[i] is False:
for result in op.results():
if result.has_one_use():
next_op = result.first_use().owner()
if (
next_op in total_ops
and relevant_op_flags[total_ops_list.index(next_op)]
is True
):
relevant_op_flags[i] = True
else:
continue
effective_ops = [ effective_ops = [
total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] total_ops[i] for i in range(len(total_ops)) if intersection_op_flags[i]
] ]
uneffective_ops = [ uneffective_ops = [
total_ops[i] total_ops[i]
for i in reversed(range(len(total_ops))) for i in reversed(range(len(total_ops)))
if not relevant_op_flags[i] if not union_op_flags[i]
] ]
return effective_ops, uneffective_ops return effective_ops, uneffective_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册