diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 5ebb0a5880a707c6aa541c573f4b6ea0b4eaea49..ca33a9a50c4137a19e27f510cc91f20e9e9b8449 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -273,18 +273,40 @@ static bool AllGradInSet(const std::vector<std::string>& names,
   return true;
 }
 
-static void CreateGradVarInBlock(BlockDescBind* block_desc,
-                                 size_t grad_op_start_index) {
+static void CreateGradVarInBlock(
+    std::unordered_map<std::string, GradVarInfo>* grad_var_record,
+    BlockDescBind* block_desc, size_t grad_op_start_index,
+    const std::unordered_map<std::string, std::string>& param_name_map) {
   auto ops = block_desc->AllOps();
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
        ++op_index) {
-    for (const auto& output : ops[op_index]->Outputs()) {
-      for (const auto& real_output : output.second) {
-        if (!block_desc->HasVar(real_output)) {
-          block_desc->Var(real_output);
-        }
-      }
-    }
+    // <<<<<<< HEAD
+    // for (const auto& output : ops[op_index]->Outputs()) {
+    //   for (const auto& real_output : output.second) {
+    //     if (!block_desc->HasVar(real_output)) {
+    //       block_desc->Var(real_output);
+    //     }
+    //   }
+    // }
+    // =======
+    ForEachVarName(ops[op_index]->Outputs(),
+                   [&](const std::string& grad_var_name) {
+                     if (block_desc->HasVar(grad_var_name)) {
+                       return false;
+                     }
+                     block_desc->Var(grad_var_name);
+                     auto it = param_name_map.find(grad_var_name);
+                     if (it == param_name_map.end()) {
+                       return false;
+                     }
+                     auto param_var_name = it->second;
+                     auto& grad_record = (*grad_var_record)[param_var_name];
+                     grad_record.name_ = grad_var_name;
+                     grad_record.block_idx_ = block_desc->ID();
+                     grad_record.op_idx_ = static_cast<int>(op_index);
+                     return false; /* not break */
+                   });
+    // >>>>>>> origin/develop
   }
 }
 
@@ -400,8 +422,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   return backward_descs;
 }
 
-void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-                    const std::unordered_set<std::string>& no_grad_vars) {
+std::unordered_map<std::string, GradVarInfo>
+AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
+               const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_var_names;
   no_grad_var_names.reserve(no_grad_vars.size() + 1);
   no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
@@ -423,20 +446,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
   all_ops.push_back(std::move(fill_one_op));
   size_t forward_op_num = all_ops.size();
   size_t forward_block_num = program_desc.Size();
+
+  // Insert backward operators
   std::unordered_map<std::string, std::string> grad_to_var;
   auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
                                              &no_grad_var_names, &grad_to_var);
+
+  std::unordered_map<std::string, GradVarInfo> retv;
+
+  // Create Variable
   for (auto& ptr : backward_op_descs) {
     all_ops.push_back(std::move(ptr));
   }
   root_block->Var(fill_one_op_out);
 
   // create grad_var for all blocks in this program
-  CreateGradVarInBlock(root_block, forward_op_num);
+  CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
   for (size_t block_index = forward_block_num;
        block_index < program_desc.Size(); ++block_index) {
-    CreateGradVarInBlock(program_desc.Block(block_index), 0);
+    CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
+                         grad_to_var);
   }
+  return retv;
 }
 
 }  // namespace framework
diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h
index 2c95d18ef7e2d997679bff442bf89d6364eb13ea..af8ad0aaa46728753273f7e700bc437f6412132f 100644
--- a/paddle/framework/backward.h
+++ b/paddle/framework/backward.h
@@ -14,7 +14,10 @@
 
 #pragma once
 
+#include <string>
+#include <unordered_map>
 #include <unordered_set>
+
 #include "paddle/framework/operator.h"
 #include "paddle/framework/program_desc.h"
 
@@ -27,10 +30,17 @@
 extern std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars);
 
+struct GradVarInfo {
+  std::string name_;
+  int block_idx_;
+  int op_idx_;
+};
+
 // TODO(jiayi): Add target as parameter and generate backward op
 // according to target.
-void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
-                    const std::unordered_set<std::string>& no_grad_vars);
+std::unordered_map<std::string, GradVarInfo>
+AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
+               const std::unordered_set<std::string>& no_grad_vars);
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index e1424c4bdec338dc2fe085c464c13018d2324ed6..9fb88f963283c72e1ec389b72dd2d98049c74f6d 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -33,15 +33,6 @@
 class ProgramDescBind;
 
 class BlockDescBind {
  public:
-  friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
-      ProgramDescBind &program_desc, int block_idx,
-      std::unordered_set<std::string> *no_grad_vars,
-      std::unordered_map<std::string, std::string> *grad_to_var);
-
-  friend void AppendBackward(
-      ProgramDescBind &program_desc, const VarDescBind &target,
-      const std::unordered_set<std::string> &no_grad_vars);
-
   BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
       : prog_(prog), desc_(desc), need_update_(false) {}
@@ -69,7 +60,9 @@
   BlockDesc *Proto();
 
- private:
+  // FIXME(yuyang18): backward will access private data of BlockDesc.
+  // Mark it public temporary. We can fix it later.
+ public:
   ProgramDescBind *prog_;  // not_own
   BlockDesc *desc_;        // not_own
   bool need_update_;