diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index c16d3e0cbe01f90a5aa9a5d7a523cd4e282e4771..1faf24bcb8828596ec37abde9e699f46526e41df 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -282,6 +282,23 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); + m.def("get_grad_op_descs", + [](const OpDescBind &op_desc, + const std::unordered_set &no_grad_set, + std::unordered_map &grad_to_var, + const std::vector &grad_sub_block) { + std::vector> grad_op_descs = + framework::OpInfoMap::Instance() + .Get(op_desc.Type()) + .GradOpMaker()(op_desc, no_grad_set, &grad_to_var, + grad_sub_block); + std::vector grad_op_desc_ptrs(grad_op_descs.size()); + std::transform( + grad_op_descs.begin(), grad_op_descs.end(), + grad_op_desc_ptrs.begin(), + [](std::unique_ptr &p) { return p.release(); }); + return grad_op_desc_ptrs; + }); m.def("prune", [](const ProgramDescBind &origin, const std::vector> &targets) { ProgramDescBind prog_with_targets(origin);