diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 9ea4e70a26203c2ab49a306426ee89f207108060..1faf24bcb8828596ec37abde9e699f46526e41df 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -282,15 +282,22 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); - m.def("get_grad_op_desc", + m.def("get_grad_op_descs", [](const OpDescBind &op_desc, const std::unordered_set &no_grad_set, std::unordered_map &grad_to_var, const std::vector &grad_sub_block) { - return framework::OpInfoMap::Instance() - .Get(op_desc.Type()) - .GradOpMaker()(op_desc, no_grad_set, &grad_to_var, - grad_sub_block); + std::vector> grad_op_descs = + framework::OpInfoMap::Instance() + .Get(op_desc.Type()) + .GradOpMaker()(op_desc, no_grad_set, &grad_to_var, + grad_sub_block); + std::vector grad_op_desc_ptrs(grad_op_descs.size()); + std::transform( + grad_op_descs.begin(), grad_op_descs.end(), + grad_op_desc_ptrs.begin(), + [](std::unique_ptr &p) { return p.release(); }); + return grad_op_desc_ptrs; }); m.def("prune", [](const ProgramDescBind &origin, const std::vector> &targets) {