diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index c16d3e0cbe01f90a5aa9a5d7a523cd4e282e4771..9ea4e70a26203c2ab49a306426ee89f207108060 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -282,6 +282,16 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); + m.def("get_grad_op_desc", + [](const OpDescBind &op_desc, + const std::unordered_set &no_grad_set, + std::unordered_map &grad_to_var, + const std::vector &grad_sub_block) { + return framework::OpInfoMap::Instance() + .Get(op_desc.Type()) + .GradOpMaker()(op_desc, no_grad_set, &grad_to_var, + grad_sub_block); + }); m.def("prune", [](const ProgramDescBind &origin, const std::vector> &targets) { ProgramDescBind prog_with_targets(origin);