未验证 提交 a4c8d977 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【newir】modify prune_ops func for mutable attribute inside program (#56564)

* modify sum with divide net bug mutablesttribute

* delete prin
上级 bb3fb69c
......@@ -258,6 +258,8 @@ void BindValue(py::module *m) {
&Value::GetDefiningOp,
return_value_policy::reference)
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
.def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
......@@ -343,6 +345,7 @@ void BindOpResult(py::module *m) {
&OpResult::GetDefiningOp,
return_value_policy::reference)
.def("first_use", &OpResult::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &OpResult::use_empty)
.def("type", &OpResult::type)
.def_property(
......
......@@ -204,6 +204,22 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
else:
relevant_op_flags[i] = False
# recover full op or full_Intarray op created by mutable attribute.
total_ops_list = list(total_ops)
for i, op in enumerate(total_ops_list):
if relevant_op_flags[i] is False:
for result in op.results():
if result.has_one_use():
next_op = result.first_use().owner()
if (
next_op in total_ops
and relevant_op_flags[total_ops_list.index(next_op)]
is True
):
relevant_op_flags[i] = True
else:
continue
effective_ops = [
total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i]
]
......
......@@ -168,5 +168,37 @@ class TesBackward_2(unittest.TestCase):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
def get_ir_program_2():
x = paddle.randn([2, 2])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
k_s = paddle.sum(x_s, axis=(-1,), keepdim=False)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program
class TestBackward_3(unittest.TestCase):
def test_basic_network(self):
newir_program = get_ir_program_2()
x = newir_program.block().ops[-1].operand(0).source()
sum_x = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
norm = paddle.tensor.fill_constant(
shape=[],
value=1.0,
dtype=sum_x.dtype,
)
res = paddle.divide(sum_x, norm)
input_grad = grad(res, x)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册