From a4c8d9774aaf8c5dd8e304645fc32055bc312323 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Thu, 24 Aug 2023 13:11:03 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90newir=E3=80=91modify=20prune=5Fops=20f?= =?UTF-8?q?unc=20for=20mutable=20attribute=20inside=20program=20(#56564)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * modify sum with divide net bug for mutable attribute * delete print --- paddle/fluid/pybind/ir.cc | 3 +++ python/paddle/autograd/backward.py | 16 +++++++++++++++ test/ir/new_ir/test_ir_backward.py | 32 ++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index a1b1ebd4abe..98f8b35c3e8 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -258,6 +258,8 @@ void BindValue(py::module *m) { &Value::GetDefiningOp, return_value_policy::reference) .def("first_use", &Value::first_use, return_value_policy::reference) + .def("has_one_use", &Value::HasOneUse) + .def("use_empty", &Value::use_empty) .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { @@ -343,6 +345,7 @@ void BindOpResult(py::module *m) { &OpResult::GetDefiningOp, return_value_policy::reference) .def("first_use", &OpResult::first_use, return_value_policy::reference) + .def("has_one_use", &Value::HasOneUse) .def("use_empty", &OpResult::use_empty) .def("type", &OpResult::type) .def_property( diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py index 67fafbae389..e631c02b0bd 100644 --- a/python/paddle/autograd/backward.py +++ b/python/paddle/autograd/backward.py @@ -204,6 +204,22 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): else: relevant_op_flags[i] = False + # recover full op or full_int_array op created by mutable attribute. 
+ total_ops_list = list(total_ops) + for i, op in enumerate(total_ops_list): + if relevant_op_flags[i] is False: + for result in op.results(): + if result.has_one_use(): + next_op = result.first_use().owner() + if ( + next_op in total_ops + and relevant_op_flags[total_ops_list.index(next_op)] + is True + ): + relevant_op_flags[i] = True + else: + continue + effective_ops = [ total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] ] diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index 63e5bdbc9e4..7adbbca86cf 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -168,5 +168,37 @@ class TesBackward_2(unittest.TestCase): paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) +def get_ir_program_2(): + x = paddle.randn([2, 2]) + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x_s = paddle.static.data('x', [4, 4], x.dtype) + x_s.stop_gradient = False + k_s = paddle.sum(x_s, axis=(-1,), keepdim=False) + newir_program = ir.translate_to_new_ir(main_program.desc) + return newir_program + + +class TestBackward_3(unittest.TestCase): + def test_basic_network(self): + newir_program = get_ir_program_2() + x = newir_program.block().ops[-1].operand(0).source() + sum_x = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + norm = paddle.tensor.fill_constant( + shape=[], + value=1.0, + dtype=sum_x.dtype, + ) + res = paddle.divide(sum_x, norm) + input_grad = grad(res, x) + + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + if __name__ == "__main__": unittest.main() -- GitLab