diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index c4ede7d2fba38036cad488b7235d57fb950ec755..d9a42be5a2d880cbe1a57ebac858dd5078f7c964 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -20,6 +20,7 @@
 
 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/program_desc.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
 
@@ -254,7 +255,7 @@ static bool AllGradInSet(const std::vector<std::string>& names,
   return true;
 }
 
-std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
+std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
     std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
@@ -295,20 +296,35 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
   for (auto& p : pending_fill_zeros_ops) {
     grad_op_descs.push_back(std::move(p));
   }
-
-  // TODO(fengjiayi): RNN op
   return grad_op_descs;
 }
 
-void AppendBackwardOpDescs(BlockDescBind& block_desc,
-                           std::unordered_set<std::string>& no_grad_vars) {
+std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
+    ProgramDescBind& program_desc, int block_idx,
+    std::unordered_set<std::string>& no_grad_vars) {
+  BlockDescBind* cur_block = program_desc.Block(block_idx);
+  std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
-  std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
-  for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
+  for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
-        MakeGradOpDescs(*it, no_grad_vars);
+        MakeOpGrad(*it, no_grad_vars);
+
+    if ((*it)->Type() == "recurrent") {
+      PADDLE_ENFORCE_EQ(
+          op_grads.size(), size_t(1),
+          "rnn_op's gradient process should contain only one op.");
+      int step_block_idx = (*it)->GetBlockAttr("step_block");
+      auto backward_block_op_descs =
+          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
+      BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
+      for (auto& ptr : backward_block_op_descs) {
+        backward_block->ops_.push_back(std::move(ptr));
+      }
+      op_grads[0]->SetBlockAttr("step_block", *backward_block);
+    }
+
     for (const auto& desc : op_grads) {
       for (const std::string& out_name : desc->OutputArgumentNames()) {
         dup_out_ops[out_name].emplace_back(grad_desc_idx);
@@ -345,11 +361,24 @@ void AppendBackwardOpDescs(BlockDescBind& block_desc,
     backward_descs.insert(backward_descs.begin() + p.first + 1,
                           std::move(p.second));
   }
-  // Append backward_descs to BlockDescBind::ops_
-  for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
-    block_op_descs.push_back(std::move(ptr));
+  return backward_descs;
+}
+
+void AppendBackward(ProgramDescBind& program_desc,
+                    const std::unordered_set<std::string>& no_grad_vars) {
+  std::unordered_set<std::string> no_grad_var_names;
+  no_grad_var_names.reserve(no_grad_vars.size() + 1);
+  no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
+  for (auto& name : no_grad_vars) {
+    no_grad_var_names.insert(GradVarName(name));
+  }
+  const int root_block_idx = 0;
+  auto backward_op_descs =
+      MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
+  auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
+  for (auto& ptr : backward_op_descs) {
+    forw_op_descs.push_back(std::move(ptr));
   }
-  return;
 }
 
 }  // namespace framework
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index fd95ef19012754315de2e36648291b65abb8a4d8..aad1c3fef820319abc41283b2ec138be59b421a9 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -32,8 +32,13 @@ class ProgramDescBind;
 
 class BlockDescBind {
  public:
-  friend void AppendBackwardOpDescs(
-      BlockDescBind &block_desc, std::unordered_set<std::string> &no_grad_vars);
+  friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
+      ProgramDescBind &program_desc, int block_idx,
+      std::unordered_set<std::string> &no_grad_vars);
+
+  friend void AppendBackward(
+      ProgramDescBind &program_desc,
+      const std::unordered_set<std::string> &no_grad_vars);
 
   BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
       : prog_(prog), desc_(desc), need_update_(false) {}
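
Usage sketch (reviewer's note, not part of the patch): the patch turns the block-level AppendBackwardOpDescs into a program-level AppendBackward that walks the root block and, for each "recurrent" op, recursively builds a backward step block via MakeBlockBackward. Assuming the new function is declared in paddle/framework/backward.h and lives in the paddle::framework namespace, a caller might look like the following; BuildTrainingProgram, main_program, and the variable name "embedding_table" are hypothetical.

#include <string>
#include <unordered_set>

#include "paddle/framework/backward.h"
#include "paddle/framework/program_desc.h"

void BuildTrainingProgram(paddle::framework::ProgramDescBind& main_program) {
  // Variables listed here are excluded from gradient computation;
  // AppendBackward maps each name to its gradient name (GradVarName)
  // internally before filtering.
  std::unordered_set<std::string> no_grad_vars{"embedding_table"};

  // Appends gradient op descriptions for every op in the root block
  // (block 0). For a "recurrent" op, a fresh backward block is appended
  // to the program and attached to the grad op's "step_block" attribute.
  paddle::framework::AppendBackward(main_program, no_grad_vars);
}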