From 044a13d02262f1cf84ee685a0575cf2ec28e5623 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 14 Dec 2017 17:50:56 +0800 Subject: [PATCH] expose GradOpMaker to Python --- paddle/pybind/pybind.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 9ea4e70a26..1faf24bcb8 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -282,15 +282,22 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); - m.def("get_grad_op_desc", + m.def("get_grad_op_descs", [](const OpDescBind &op_desc, const std::unordered_set &no_grad_set, std::unordered_map &grad_to_var, const std::vector &grad_sub_block) { - return framework::OpInfoMap::Instance() - .Get(op_desc.Type()) - .GradOpMaker()(op_desc, no_grad_set, &grad_to_var, - grad_sub_block); + std::vector> grad_op_descs = + framework::OpInfoMap::Instance() + .Get(op_desc.Type()) + .GradOpMaker()(op_desc, no_grad_set, &grad_to_var, + grad_sub_block); + std::vector grad_op_desc_ptrs(grad_op_descs.size()); + std::transform( + grad_op_descs.begin(), grad_op_descs.end(), + grad_op_desc_ptrs.begin(), + [](std::unique_ptr &p) { return p.release(); }); + return grad_op_desc_ptrs; }); m.def("prune", [](const ProgramDescBind &origin, const std::vector> &targets) { -- GitLab