From 9a0ef7d2aa762b33cfc9bd5145550647db83d2f2 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 14 Oct 2017 20:04:43 -0700 Subject: [PATCH] append_backward return map to python --- paddle/framework/backward.cc | 6 +++--- paddle/framework/backward.h | 11 ++++++----- paddle/pybind/protobuf.cc | 14 +++++++++++++- python/paddle/v2/framework/tests/test_program.py | 10 +++++++++- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 719ac7c80..102fe2e67 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -412,9 +412,9 @@ std::vector> MakeBlockBackward( return backward_descs; } -std::unordered_map -AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, - const std::unordered_set& no_grad_vars) { +ParamGradInfoMap AppendBackward( + ProgramDescBind& program_desc, const VarDescBind& target, + const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_var_names; no_grad_var_names.reserve(no_grad_vars.size() + 1); no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix); diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index af8ad0aaa..e94bdeab7 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -36,11 +36,12 @@ struct GradVarInfo { int op_idx_; }; -// TODO(jiayi): Add target as parameter and generate backward op -// according to target. -std::unordered_map -AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, - const std::unordered_set& no_grad_vars); +using ParamGradInfoMap = std::unordered_map; + +ParamGradInfoMap AppendBackward( + ProgramDescBind& program_desc, const VarDescBind& target, + const std::unordered_set& no_grad_vars); } // namespace framework } // namespace paddle diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 2acfc28b6..df94647af 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) { .def("append_backward", [](ProgramDescBind &program_desc, const VarDescBind &target, const std::unordered_set &no_grad_vars) { - AppendBackward(program_desc, target, no_grad_vars); + ParamGradInfoMap param_grad_map = + AppendBackward(program_desc, target, no_grad_vars); + std::unordered_map< + std::string, std::tuple> + retv; + for (auto it = param_grad_map.begin(); it != param_grad_map.end(); + ++it) { + const auto &grad_info = it->second; + retv[it->first] = std::make_tuple( + grad_info.name_, grad_info.block_idx_, grad_info.op_idx_); + } + return retv; }) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("num_blocks", &ProgramDescBind::Size) diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index c5674382a..cd209b058 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -57,8 +57,16 @@ class TestProgram(unittest.TestCase): "mul", "elementwise_add", "fill_constant", "elementwise_add_grad", "mul_grad" ] + + def grad_name(name): + return name + "@GRAD" + actual_ops = [] - prog.append_backward(target, set()) + param_to_grad = prog.append_backward(target, set()) + for var_name in ("x1", "y1", "out1", "b1"): + self.assertEqual(param_to_grad[var_name][0], grad_name(var_name)) + self.assertEqual(param_to_grad[var_name][1], 0) + for op in block.all_ops(): actual_ops.append(op.type()) print(actual_ops) -- GitLab