diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 321483833e114f4a7938a9d0dac1416ae2f91a98..719ac7c80a8c57d6a937e05fa0aefa8ba889ecde 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -273,18 +273,30 @@ static bool AllGradInSet(const std::vector& names, return true; } -static void CreateGradVarInBlock(BlockDescBind* block_desc, - size_t grad_op_start_index) { +static void CreateGradVarInBlock( + std::unordered_map* grad_var_record, + BlockDescBind* block_desc, size_t grad_op_start_index, + const std::unordered_map& param_name_map) { auto ops = block_desc->AllOps(); for (size_t op_index = grad_op_start_index; op_index < ops.size(); ++op_index) { - for (const auto& output : ops[op_index]->Outputs()) { - for (const auto& real_output : output.second) { - if (!block_desc->HasVar(real_output)) { - block_desc->NewVar(real_output); - } - } - } + ForEachVarName(ops[op_index]->Outputs(), + [&](const std::string& grad_var_name) { + if (block_desc->HasVar(grad_var_name)) { + return false; + } + block_desc->NewVar(grad_var_name); + auto it = param_name_map.find(grad_var_name); + if (it == param_name_map.end()) { + return false; + } + auto param_var_name = it->second; + auto& grad_record = (*grad_var_record)[param_var_name]; + grad_record.name_ = grad_var_name; + grad_record.block_idx_ = block_desc->ID(); + grad_record.op_idx_ = static_cast(op_index); + return false; /* not break */ + }); } } @@ -400,8 +412,9 @@ std::vector> MakeBlockBackward( return backward_descs; } -void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, - const std::unordered_set& no_grad_vars) { +std::unordered_map +AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, + const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_var_names; no_grad_var_names.reserve(no_grad_vars.size() + 1); no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix); @@ -423,20 +436,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, all_ops.push_back(std::move(fill_one_op)); size_t forward_op_num = all_ops.size(); size_t forward_block_num = program_desc.Size(); + + // Insert backward operators std::unordered_map grad_to_var; auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx, &no_grad_var_names, &grad_to_var); + + std::unordered_map retv; + + // Create Variable for (auto& ptr : backward_op_descs) { all_ops.push_back(std::move(ptr)); } root_block->NewVar(fill_one_op_out); // create grad_var for all blocks in this program - CreateGradVarInBlock(root_block, forward_op_num); + CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var); for (size_t block_index = forward_block_num; block_index < program_desc.Size(); ++block_index) { - CreateGradVarInBlock(program_desc.Block(block_index), 0); + CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0, + grad_to_var); } + return retv; } } // namespace framework diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index 2c95d18ef7e2d997679bff442bf89d6364eb13ea..af8ad0aaa46728753273f7e700bc437f6412132f 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -14,7 +14,10 @@ #pragma once +#include +#include #include + #include "paddle/framework/operator.h" #include "paddle/framework/program_desc.h" @@ -27,10 +30,17 @@ extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); +struct GradVarInfo { + std::string name_; + int block_idx_; + int op_idx_; +}; + // TODO(jiayi): Add target as parameter and generate backward op // according to target. -void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, - const std::unordered_set& no_grad_vars); +std::unordered_map +AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target, + const std::unordered_set& no_grad_vars); } // namespace framework } // namespace paddle diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 4446576a3cb7a7534985d85e48c56b66430eccdf..50f88ec2f7f7b5f69c3a56db370b45f628498fe6 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -33,15 +33,6 @@ class ProgramDescBind; class BlockDescBind { public: - friend std::vector> MakeBlockBackward( - ProgramDescBind &program_desc, int block_idx, - std::unordered_set *no_grad_vars, - std::unordered_map *grad_to_var); - - friend void AppendBackward( - ProgramDescBind &program_desc, const VarDescBind &target, - const std::unordered_set &no_grad_vars); - BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) : prog_(prog), desc_(desc), need_update_(false) {} @@ -69,7 +60,9 @@ class BlockDescBind { BlockDesc *Proto(); - private: + // FIXME(yuyang18): backward will access private data of BlockDesc. + // Mark it public temporary. We can fix it later. + public: ProgramDescBind *prog_; // not_own BlockDesc *desc_; // not_own bool need_update_;