diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index a1b1ebd4abe47c24c834408f7ffa2fb205606c47..98f8b35c3e8fd7ef01e3f97969e6c0ada4df8bc6 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -258,6 +258,8 @@ void BindValue(py::module *m) { &Value::GetDefiningOp, return_value_policy::reference) .def("first_use", &Value::first_use, return_value_policy::reference) + .def("has_one_use", &Value::HasOneUse) + .def("use_empty", &Value::use_empty) .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { @@ -343,6 +345,7 @@ void BindOpResult(py::module *m) { &OpResult::GetDefiningOp, return_value_policy::reference) .def("first_use", &OpResult::first_use, return_value_policy::reference) + .def("has_one_use", &Value::HasOneUse) .def("use_empty", &OpResult::use_empty) .def("type", &OpResult::type) .def_property( diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py index 67fafbae389c5e42bf63c492d81bd8caddb7762b..e631c02b0bd3e03c341334c3da8653a29606d53f 100644 --- a/python/paddle/autograd/backward.py +++ b/python/paddle/autograd/backward.py @@ -204,6 +204,22 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): else: relevant_op_flags[i] = False + # recover full op or full_int_array op created by mutable attribute. 
+ total_ops_list = list(total_ops) + for i, op in enumerate(total_ops_list): + if relevant_op_flags[i] is False: + for result in op.results(): + if result.has_one_use(): + next_op = result.first_use().owner() + if ( + next_op in total_ops + and relevant_op_flags[total_ops_list.index(next_op)] + is True + ): + relevant_op_flags[i] = True + else: + continue + effective_ops = [ total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] ] diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index 63e5bdbc9e4c703993b818a59745e6f722bc0a53..7adbbca86cf6116af7052602678bf89bb2259c5e 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -168,5 +168,37 @@ class TesBackward_2(unittest.TestCase): paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) +def get_ir_program_2(): + x = paddle.randn([2, 2]) + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x_s = paddle.static.data('x', [4, 4], x.dtype) + x_s.stop_gradient = False + k_s = paddle.sum(x_s, axis=(-1,), keepdim=False) + newir_program = ir.translate_to_new_ir(main_program.desc) + return newir_program + + +class TestBackward_3(unittest.TestCase): + def test_basic_network(self): + newir_program = get_ir_program_2() + x = newir_program.block().ops[-1].operand(0).source() + sum_x = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + norm = paddle.tensor.fill_constant( + shape=[], + value=1.0, + dtype=sum_x.dtype, + ) + res = paddle.divide(sum_x, norm) + input_grad = grad(res, x) + + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + if __name__ == "__main__": unittest.main()