From 9935fdd3dd92cf9930f88b070090925d2909ed1a Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 3 Oct 2017 14:14:42 -0700 Subject: [PATCH] Update --- paddle/framework/backward.cc | 57 ++++++++++++++++++++++++++++++++--- paddle/framework/block_desc.h | 4 +++ paddle/framework/op_desc.cc | 16 ++++++++++ paddle/framework/op_desc.h | 4 +++ 4 files changed, 77 insertions(+), 4 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 1b4c5c025e..0f65478ef8 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -235,14 +235,17 @@ static bool AllGradInSet(const std::vector& names, } std::vector CreatBackwardOps( - const OpDescBind& op_desc, unordered_map& no_grad_vars) { + const std::unique_ptr& op_desc_ptr, + unordered_map& no_grad_vars) { + const OpDescBind& op_desc = *op_desc_ptr; std::vector grad_op_descs; // All input gradients of forwarding operator do not need to calculat. - if (AllGradInSet(op_desc_.InputNames(), kGradVarSuffix, no_grad_vars)) { + if (AllGradInSet(op_desc_.InputArgumentNames(), kGradVarSuffix, + no_grad_vars)) { return grad_op_descs; // empty vector } // All output gradients of forwarding operator do not need to calculate. - const std::vector& outputs = op_desc_.OutputNames(); + const std::vector& outputs = op_desc_.OutputArugumentNames(); if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) { for (const std::string& name : outputs) { no_grad_vars.insert(GradVarName(name)); @@ -254,7 +257,7 @@ std::vector CreatBackwardOps( std::vector fill_zeros_ops; for (OpDescBind& desc : grad_op_descs) { - for (const std::string& in_name : desc.InputNames()) { + for (const std::string& in_name : desc.InputArgumentNames()) { if (no_grad_vars.count(in_name)) { std::string prefix = in_name.substr( 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); @@ -278,5 +281,51 @@ std::vector CreatBackwardOps( return grad_op_descs; } +void AppendBackwardOps(BlockDescBind& block_desc, + const std::unordered_set& no_grad_vars) { + std::unordered_map> dup_out_ops; + size_t grad_desc_idx = 0; + std::deque> op_descs = block_desc.ops_; + std::vector> grad_op_descs; + for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { + std::vector op_grads = CreatBackwardOps(*it, no_grad_vars); + for (const OpDescBind& desc : op_grads) { + for (const std::string& out_name : desc.OutputArugumentNames()) { + dup_out_ops[out_name].emplace_back(grad_desc_idx); + } + ++grad_desc_idx; + } + grad_op_descs.insert(grad_op_descs.end(), op_grads.begin(), op_grads.end()); + } + // Check whether some variables are written more than once + std::list> pending_sum_ops; + for (const auto& dup : dup_out_ops) { + const std::string& out_name = dup.first; + const std::vector dup_op = dup.second; + if (out_name != kEmptyVarName && dup_op.size() > 1) { + std::vector sum_op_inputs; + for (size_t i = 0; i < dup_op.size(); ++i) { + std::string new_name = out_name + "@RENAME@" + std::to_string(i); + grad_op_descs[dup_op[i]].Rename(out_name, new_name); + sum_op_inputs.emplace_back(new_name); + } + pending_sum_ops.push_back( + {dup_op.back(), + OpDescBind( + {"sum", {{"X", {sum_op_inputs}}}, {{"Out", {out_name}}}, {}})}); + } + } + pending_sum_ops.sort( + [](const std::pair& a, + const std::pair& b) { return a.first > b.first; }); + for (auto& p : pending_sum_ops) { + grad_op_descs.insert(grad_op_descs.begin() + p.first + 1, + std::move(p.second)); + } + // Append grad_op_descs to BlockDescBind::ops_ + for () { + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 59513ede33..a171dfef30 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -32,6 +32,10 @@ class ProgramDescBind; class BlockDescBind { public: + friend void AppendBackwardOps( + BlockDescBind &block_desc, + const std::unordered_set &no_grad_vars); + BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) : prog_(prog), desc_(desc), need_update_(false) {} diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index f2e0c14fbd..e6c0cdacd9 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -49,6 +49,14 @@ std::vector OpDescBind::InputNames() const { return retv; } +std::vector InputArgumentNames() const { + std::vector retv; + for (auto &ipt : this->inputs_) { + retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); + } + return retv; +} + void OpDescBind::SetInput(const std::string ¶m_name, const std::vector &args) { need_update_ = true; @@ -72,6 +80,14 @@ std::vector OpDescBind::OutputNames() const { return retv; } +std::vector OutputArgumentNames() const { + std::vector retv; + for (auto &ipt : this->outputs_) { + retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); + } + return retv; +} + void OpDescBind::SetOutput(const std::string ¶m_name, const std::vector &args) { need_update_ = true; diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 2280654812..e30c58632e 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -42,6 +42,8 @@ class OpDescBind { std::vector InputNames() const; + std::vector InputArgumentNames() const; + void SetInput(const std::string ¶m_name, const std::vector &args); @@ -49,6 +51,8 @@ class OpDescBind { std::vector OutputNames() const; + std::vector OutputArgumentNames() const; + void SetOutput(const std::string ¶m_name, const std::vector &args); -- GitLab