diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index ca33a9a50c4137a19e27f510cc91f20e9e9b8449..ca9163c037381fc015f992a0bbc8f484e33de14d 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
 }
 
 static void CreateGradVarInBlock(
-    std::unordered_map<std::string, GradVarInfo>* grad_var_record,
-    BlockDescBind* block_desc, size_t grad_op_start_index,
-    const std::unordered_map<std::string, std::string>& param_name_map) {
+    size_t grad_op_start_index,
+    const std::unordered_map<std::string, std::string>& param_name_map,
+    BlockDescBind* block_desc,
+    std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
   auto ops = block_desc->AllOps();
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
        ++op_index) {
@@ -422,9 +423,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   return backward_descs;
 }
 
-std::unordered_map<std::string, GradVarInfo>
-AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-               const std::unordered_set<std::string>& no_grad_vars) {
+ParamGradInfoMap AppendBackward(
+    ProgramDescBind& program_desc, const VarDescBind& target,
+    const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_var_names;
   no_grad_var_names.reserve(no_grad_vars.size() + 1);
   no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
@@ -461,11 +462,11 @@ AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
   root_block->Var(fill_one_op_out);
 
   // create grad_var for all blocks in this program
-  CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
+  CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
   for (size_t block_index = forward_block_num;
        block_index < program_desc.Size(); ++block_index) {
-    CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
-                         grad_to_var);
+    CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
+                         &retv);
   }
   return retv;
 }
diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h
index af8ad0aaa46728753273f7e700bc437f6412132f..e94bdeab7ec9952e65186bec56e43cc411721101 100644
--- a/paddle/framework/backward.h
+++ b/paddle/framework/backward.h
@@ -36,11 +36,12 @@ struct GradVarInfo {
   int op_idx_;
 };
 
-// TODO(jiayi): Add target as parameter and generate backward op
-// according to target.
-std::unordered_map<std::string, GradVarInfo>
-AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-               const std::unordered_set<std::string>& no_grad_vars);
+using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
+                                            GradVarInfo /*grad_var_info*/>;
+
+ParamGradInfoMap AppendBackward(
+    ProgramDescBind& program_desc, const VarDescBind& target,
+    const std::unordered_set<std::string>& no_grad_vars);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc
index b6327f8500bbbb66575d3bc928b38ab208296a44..b360b05d16c9a1c135fa56cb37919dece8f16788 100644
--- a/paddle/pybind/protobuf.cc
+++ b/paddle/pybind/protobuf.cc
@@ -120,7 +120,19 @@ void BindProgramDesc(py::module &m) {
       .def("append_backward",
           [](ProgramDescBind &program_desc, const VarDescBind &target,
              const std::unordered_set<std::string> &no_grad_vars) {
-             AppendBackward(program_desc, target, no_grad_vars);
+             ParamGradInfoMap param_grad_map =
+                 AppendBackward(program_desc, target, no_grad_vars);
+             std::unordered_map<
+                 std::string, std::tuple<std::string /* grad_var_name */,
+                                         int /* block_idx */, int /* op_idx */>>
+                 retv;
+             for (auto it = param_grad_map.begin(); it != param_grad_map.end();
+                  ++it) {
+               const auto &grad_info = it->second;
+               retv[it->first] = std::make_tuple(
+                   grad_info.name_, grad_info.block_idx_, grad_info.op_idx_);
+             }
+             return retv;
            })
       .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
       .def("num_blocks", &ProgramDescBind::Size)
diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py
index 07473d17f76b724b35c49b1a713beeb30d251088..7c521cd634ca570ab282b83a3536c64808332cea 100644
--- a/python/paddle/v2/framework/tests/test_program.py
+++ b/python/paddle/v2/framework/tests/test_program.py
@@ -57,11 +57,18 @@ class TestProgram(unittest.TestCase):
             "mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
             "mul_grad"
         ]
+
+        def grad_name(name):
+            return name + "@GRAD"
+
         actual_ops = []
-        prog.append_backward(target, set())
+        param_to_grad = prog.append_backward(target, set())
+        for var_name in ("x1", "y1", "out1", "b1"):
+            self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
+            self.assertEqual(param_to_grad[var_name][1], 0)
+
         for op in block.all_ops():
             actual_ops.append(op.type())
-        print(actual_ops)
         self.assertEqual(actual_ops, expect_ops)
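For context, the pybind lambda above converts the C++ `ParamGradInfoMap` into a plain Python dict mapping each forward variable name to a `(grad_var_name, block_idx, op_idx)` tuple. Below is a minimal sketch of consuming that return value from Python, modeled on the updated test; `prog` and `target` are assumed to be set up as in `test_program.py`:

```python
# Sketch: prog and target are assumed to be built as in test_program.py.
# After this diff, append_backward returns a dict of
#   {param_name: (grad_var_name, block_idx, op_idx)}.
param_to_grad = prog.append_backward(target, set())

for var_name in ("x1", "y1", "out1", "b1"):
    grad_var_name, block_idx, op_idx = param_to_grad[var_name]
    # Gradient variables take the "@GRAD" suffix, and for this
    # single-block program they live in the root block (index 0).
    assert grad_var_name == var_name + "@GRAD"
    assert block_idx == 0
```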