From e47770bd27a20b2fa9bf2754d16b0e71008185e5 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 3 Oct 2017 15:30:21 -0700 Subject: [PATCH] Update --- paddle/framework/backward.cc | 85 +++++++++++++++++---------------- paddle/framework/backward.h | 2 +- paddle/framework/op_desc.cc | 13 ++--- paddle/framework/op_registry.cc | 3 +- paddle/framework/op_registry.h | 3 +- 5 files changed, 57 insertions(+), 49 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0f65478ef81..b4eb89e2d7b 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -234,18 +234,17 @@ static bool AllGradInSet(const std::vector& names, return true; } -std::vector CreatBackwardOps( - const std::unique_ptr& op_desc_ptr, - unordered_map& no_grad_vars) { - const OpDescBind& op_desc = *op_desc_ptr; - std::vector grad_op_descs; +std::vector> MakeGradOpDescs( + const std::unique_ptr& op_desc, + unordered_set& no_grad_vars) { + std::vector> grad_op_descs; // All input gradients of forwarding operator do not need to calculat. - if (AllGradInSet(op_desc_.InputArgumentNames(), kGradVarSuffix, + if (AllGradInSet(op_desc->InputArgumentNames(), kGradVarSuffix, no_grad_vars)) { return grad_op_descs; // empty vector } // All output gradients of forwarding operator do not need to calculate. - const std::vector& outputs = op_desc_.OutputArugumentNames(); + const std::vector& outputs = op_desc->OutputArugumentNames(); if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) { for (const std::string& name : outputs) { no_grad_vars.insert(GradVarName(name)); @@ -255,50 +254,54 @@ std::vector CreatBackwardOps( grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc); - std::vector fill_zeros_ops; - for (OpDescBind& desc : grad_op_descs) { - for (const std::string& in_name : desc.InputArgumentNames()) { + std::list> pending_fill_zeros_ops; + for (auto& desc : grad_op_descs) { + for (const std::string& in_name : desc->InputArgumentNames()) { if (no_grad_vars.count(in_name)) { std::string prefix = in_name.substr( 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); std::string new_name = prefix + kZeroVarSuffix; - desc.Rename(in_name, new_name); - OpDescBind op_desc_bind( - {"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}}); - fill_zeros_ops.push_back(op_desc_bind); + desc->Rename(in_name, new_name); + OpDescBind* fill_zeros_op = new OpDescBind( + "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}); + pending_fill_zeros_ops.push_back({fill_zeros_op}); } } - for (const std::string& out_name : desc.OutputName()) { + for (const std::string& out_name : desc->OutputArgumentName()) { if (no_grad_vars.count(out_name)) { - desc.Rename(out_name, kEmptyVarName); + desc->Rename(out_name, kEmptyVarName); } } } - grad_op_descs.insert(grad_op_descs.begin(), fill_zeros_ops.begin(), - fill_zeros_ops.end()); + grad_op_descs.insert(std::begin(grad_op_descs), + std::begin(pending_fill_zeros_ops), + std::end(pending_fill_zeros_ops)); // TODO (fengjiayi): RNN op return grad_op_descs; } -void AppendBackwardOps(BlockDescBind& block_desc, - const std::unordered_set& no_grad_vars) { +void AppendBackwardOpDescs( + BlockDescBind& block_desc, + const std::unordered_set& no_grad_vars) { std::unordered_map> dup_out_ops; size_t grad_desc_idx = 0; - std::deque> op_descs = block_desc.ops_; - std::vector> grad_op_descs; - for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { - std::vector op_grads = CreatBackwardOps(*it, no_grad_vars); - for (const OpDescBind& desc : op_grads) { - for (const std::string& out_name : desc.OutputArugumentNames()) { + std::deque> block_op_descs = block_desc.ops_; + std::vector> backward_descs; + for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) { + std::vector> op_grads = + MakeGradOpDescs(*it, no_grad_vars); + for (const auto& desc : op_grads) { + for (const std::string& out_name : desc->OutputArugumentNames()) { dup_out_ops[out_name].emplace_back(grad_desc_idx); } ++grad_desc_idx; } - grad_op_descs.insert(grad_op_descs.end(), op_grads.begin(), op_grads.end()); + backward_descs.insert(backward_descs.end(), op_grads.begin(), + op_grads.end()); } // Check whether some variables are written more than once - std::list> pending_sum_ops; + std::list>> pending_sum_ops; for (const auto& dup : dup_out_ops) { const std::string& out_name = dup.first; const std::vector dup_op = dup.second; @@ -306,25 +309,27 @@ void AppendBackwardOps(BlockDescBind& block_desc, std::vector sum_op_inputs; for (size_t i = 0; i < dup_op.size(); ++i) { std::string new_name = out_name + "@RENAME@" + std::to_string(i); - grad_op_descs[dup_op[i]].Rename(out_name, new_name); + backward_descs[dup_op[i]]->Rename(out_name, new_name); sum_op_inputs.emplace_back(new_name); } - pending_sum_ops.push_back( - {dup_op.back(), - OpDescBind( - {"sum", {{"X", {sum_op_inputs}}}, {{"Out", {out_name}}}, {}})}); + OpDescBind* sum_op = new OpDescBind("sum", {{"X", sum_op_inputs}}, + {{"Out", {out_name}}}, {}); + pending_sum_ops.push_back({dup_op.back(), {sum_op}}); } } pending_sum_ops.sort( - [](const std::pair& a, - const std::pair& b) { return a.first > b.first; }); + [](const std::pair>& a, + const std::pair>& b) { + return a.first > b.first; + }); for (auto& p : pending_sum_ops) { - grad_op_descs.insert(grad_op_descs.begin() + p.first + 1, - std::move(p.second)); - } - // Append grad_op_descs to BlockDescBind::ops_ - for () { + backward_descs.insert(backward_descs.begin() + p.first + 1, + std::move(p.second)); } + // Append backward_descs to BlockDescBind::ops_ + block_op_descs.insert(std::end(block_op_descs), std::begin(backward_descs), + std::end(backward_descs)); + return; } } // namespace framework diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index 6aeddafb41e..fb496c34c7d 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -24,7 +24,7 @@ extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); -extern void AppendBackwardOps( +extern void AppendBackwardOpDescs( BlockDescBind& block_desc, const std::unordered_set& no_grad_vars); diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index e6c0cdacd96..2c6aec717bd 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -49,7 +49,7 @@ std::vector OpDescBind::InputNames() const { return retv; } -std::vector InputArgumentNames() const { +std::vector OpDescBind::InputArgumentNames() const { std::vector retv; for (auto &ipt : this->inputs_) { retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); @@ -80,7 +80,7 @@ std::vector OpDescBind::OutputNames() const { return retv; } -std::vector OutputArgumentNames() const { +std::vector OpDescBind::OutputArgumentNames() const { std::vector retv; for (auto &ipt : this->outputs_) { retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); @@ -137,12 +137,13 @@ const std::unordered_map &OpDescBind::GetAttrMap() return attrs_; } -void Rename(const std::string &old_name, const std::string &new_name) { - for (std : string &input : inputs_) { +void OpDescBind::Rename(const std::string &old_name, + const std::string &new_name) { + for (auto &input : inputs_) { std::replace(input.second.begin(), input.second.end(), old_name, new_name); } - for (std::string &output : outputs_) { - std::repalce(output.second.begin(), output.second.end(), old_name, + for (auto &output : outputs_) { + std::replace(output.second.begin(), output.second.end(), old_name, new_name); } need_update_ = true; diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index fe3228ce5bf..d8851a8b42f 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -57,7 +57,8 @@ std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { return std::unique_ptr(BuildGradOp(&op)); } -static std::vector CreateGradOpDescs(const OpDescBind& op_desc) { +static std::vector> OpRegistry::CreateGradOpDescs( + const OpDescBind& op_desc) { auto& info = OpInfoMap::Instance().Get(op_desc.Type()); return info.grad_op_maker_(op_desc); } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index c80b6e9630b..e334cd592a2 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -69,7 +69,8 @@ class OpRegistry { static std::unique_ptr CreateGradOp(const OperatorBase& op); - static std::vector CreateGradOpDescs(const OpDescBind& op_desc); + static std::vector> CreateGradOpDescs( + const OpDescBind& op_desc); }; class Registrar { -- GitLab