diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 719ac7c80a8c57d6a937e05fa0aefa8ba889ecde..102fe2e67f33ca235451f130aaffe392c3890acd 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -412,9 +412,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   return backward_descs;
 }
 
-std::unordered_map<std::string, GradVarInfo>
-AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-               const std::unordered_set<std::string>& no_grad_vars) {
+ParamGradInfoMap AppendBackward(
+    ProgramDescBind& program_desc, const VarDescBind& target,
+    const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_var_names;
   no_grad_var_names.reserve(no_grad_vars.size() + 1);
   no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h
index af8ad0aaa46728753273f7e700bc437f6412132f..e94bdeab7ec9952e65186bec56e43cc411721101 100644
--- a/paddle/framework/backward.h
+++ b/paddle/framework/backward.h
@@ -36,11 +36,12 @@ struct GradVarInfo {
   int op_idx_;
 };
 
-// TODO(jiayi): Add target as parameter and generate backward op
-// according to target.
-std::unordered_map<std::string, GradVarInfo>
-AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-               const std::unordered_set<std::string>& no_grad_vars);
+using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
+                                            GradVarInfo>;
+
+ParamGradInfoMap AppendBackward(
+    ProgramDescBind& program_desc, const VarDescBind& target,
+    const std::unordered_set<std::string>& no_grad_vars);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc
index 2acfc28b66456c4ecf159bc6a714c939e98ecd24..df94647aff79bf850174daee8ff8a914ea2e6396 100644
--- a/paddle/pybind/protobuf.cc
+++ b/paddle/pybind/protobuf.cc
@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) {
       .def("append_backward",
            [](ProgramDescBind &program_desc, const VarDescBind &target,
               const std::unordered_set<std::string> &no_grad_vars) {
-             AppendBackward(program_desc, target, no_grad_vars);
+             ParamGradInfoMap param_grad_map =
+                 AppendBackward(program_desc, target, no_grad_vars);
+             std::unordered_map<
+                 std::string, std::tuple<std::string /*grad_var_name*/,
+                                         int /*block_idx*/, int /*op_idx*/>>
+                 retv;
+             for (auto it = param_grad_map.begin(); it != param_grad_map.end();
+                  ++it) {
+               const auto &grad_info = it->second;
+               retv[it->first] = std::make_tuple(
+                   grad_info.name_, grad_info.block_idx_, grad_info.op_idx_);
+             }
+             return retv;
            })
       .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
       .def("num_blocks", &ProgramDescBind::Size)
diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py
index c5674382a484a91268e0139ba5588b123531210e..cd209b0585fa136fda7fdd540ae3b691eacc8994 100644
--- a/python/paddle/v2/framework/tests/test_program.py
+++ b/python/paddle/v2/framework/tests/test_program.py
@@ -57,8 +57,16 @@ class TestProgram(unittest.TestCase):
             "mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
             "mul_grad"
         ]
+
+        def grad_name(name):
+            return name + "@GRAD"
+
         actual_ops = []
-        prog.append_backward(target, set())
+        param_to_grad = prog.append_backward(target, set())
+        for var_name in ("x1", "y1", "out1", "b1"):
+            self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
+            self.assertEqual(param_to_grad[var_name][1], 0)
+
         for op in block.all_ops():
             actual_ops.append(op.type())
         print(actual_ops)