diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 4c76326e7cecf3dc9a0298864803578c05d49d3b..a84262e0075aa4e29197acb4b10fbf3c5712af0d 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -14,9 +14,11 @@
 #include "paddle/framework/backward.h"
 
+#include <deque>
 #include <list>
 #include <memory>
 
+#include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
 
@@ -254,23 +256,22 @@ static bool AllGradInSet(const std::vector<std::string>& names,
 
 std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
     const std::unique_ptr<OpDescBind>& op_desc,
-    unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
   // All input gradients of forwarding operator do not need to calculat.
-  if (AllGradInSet(op_desc->InputArgumentNames(), kGradVarSuffix,
-                   no_grad_vars)) {
+  if (AllGradInSet(op_desc->InputArgumentNames(), no_grad_vars)) {
     return grad_op_descs;  // empty vector
   }
   // All output gradients of forwarding operator do not need to calculate.
-  const std::vector<std::string>& outputs = op_desc->OutputArugumentNames();
-  if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
+  const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
+  if (AllGradInSet(outputs, no_grad_vars)) {
     for (const std::string& name : outputs) {
       no_grad_vars.insert(GradVarName(name));
     }
     return grad_op_descs;  // empty vector
   }
 
-  grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
+  grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
 
   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
@@ -280,43 +281,43 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
             0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
         std::string new_name = prefix + kZeroVarSuffix;
         desc->Rename(in_name, new_name);
-        OpDescBind* fill_zeros_op = new OpDescBind(
-            "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {});
-        pending_fill_zeros_ops.push_back({fill_zeros_op});
+        std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
+            "fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
+        pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
       }
     }
-    for (const std::string& out_name : desc->OutputArgumentName()) {
+    for (const std::string& out_name : desc->OutputArgumentNames()) {
       if (no_grad_vars.count(out_name)) {
         desc->Rename(out_name, kEmptyVarName);
       }
     }
   }
-  grad_op_descs.insert(std::begin(grad_op_descs),
-                       std::begin(pending_fill_zeros_ops),
-                       std::end(pending_fill_zeros_ops));
+  for (auto& p : pending_fill_zeros_ops) {
+    grad_op_descs.push_back(std::move(p));
+  }
 
-  // TODO (fengjiayi): RNN op
+  // TODO(fengjiayi): RNN op
   return grad_op_descs;
 }
 
-void AppendBackwardOpDescs(
-    BlockDescBind& block_desc,
-    const std::unordered_set<std::string>& no_grad_vars) {
+void AppendBackwardOpDescs(BlockDescBind& block_desc,
+                           std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
-  std::deque<std::unique_ptr<OpDescBind>> block_op_descs = block_desc.ops_;
+  std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
   for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
         MakeGradOpDescs(*it, no_grad_vars);
     for (const auto& desc : op_grads) {
-      for (const std::string& out_name : desc->OutputArugumentNames()) {
+      for (const std::string& out_name : desc->OutputArgumentNames()) {
         dup_out_ops[out_name].emplace_back(grad_desc_idx);
       }
       ++grad_desc_idx;
     }
-    backward_descs.insert(backward_descs.end(), op_grads.begin(),
-                          op_grads.end());
+    std::transform(
+        op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
+        [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
   }
   // Check whether some variables are written more than once
   std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
@@ -330,9 +331,9 @@ void AppendBackwardOpDescs(
         backward_descs[dup_op[i]]->Rename(out_name, new_name);
         sum_op_inputs.emplace_back(new_name);
       }
-      OpDescBind* sum_op = new OpDescBind("sum", {{"X", sum_op_inputs}},
-                                          {{"Out", {out_name}}}, {});
-      pending_sum_ops.push_back({dup_op.back(), {sum_op}});
+      std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
+          "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
+      pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
     }
   }
   pending_sum_ops.sort(
@@ -345,8 +346,9 @@ void AppendBackwardOpDescs(
         std::move(p.second));
   }
 
   // Append backward_descs to BlockDescBind::ops_
-  block_op_descs.insert(std::end(block_op_descs), std::begin(backward_descs),
-                        std::end(backward_descs));
+  for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
+    block_op_descs.push_back(std::move(ptr));
+  }
   return;
 }
 
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index a171dfef30ad4740df92a5584633372ba643a6c3..fd95ef19012754315de2e36648291b65abb8a4d8 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -32,9 +32,8 @@ class ProgramDescBind;
 
 class BlockDescBind {
  public:
-  friend void AppendBackwardOps(
-      BlockDescBind &block_desc,
-      const std::unordered_set<std::string> &no_grad_vars);
+  friend void AppendBackwardOpDescs(
+      BlockDescBind &block_desc, std::unordered_set<std::string> &no_grad_vars);
 
   BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
       : prog_(prog), desc_(desc), need_update_(false) {}
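A note on the recurring container change in this patch: `std::unique_ptr` is move-only, so the old bulk `insert(first, last)` calls, which copy their range, cannot compile once the containers own their `OpDescBind`s. The patch therefore moves elements out one by one, either with an explicit loop or with `std::transform` and a move-returning lambda. Below is a minimal, self-contained sketch of both variants; `FakeOpDesc` is a hypothetical stand-in for `OpDescBind`:

```cpp
#include <algorithm>
#include <deque>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for OpDescBind; it only carries a name.
struct FakeOpDesc {
  std::string type;
  explicit FakeOpDesc(std::string t) : type(std::move(t)) {}
};

int main() {
  std::vector<std::unique_ptr<FakeOpDesc>> op_grads;
  op_grads.emplace_back(new FakeOpDesc("mul_grad"));
  op_grads.emplace_back(new FakeOpDesc("add_grad"));

  // Variant used inside the loop of AppendBackwardOpDescs:
  // std::transform with a lambda that moves each unique_ptr out.
  std::vector<std::unique_ptr<FakeOpDesc>> backward_descs;
  std::transform(op_grads.begin(), op_grads.end(),
                 std::back_inserter(backward_descs),
                 [](std::unique_ptr<FakeOpDesc>& p) { return std::move(p); });

  // Variant used when appending to BlockDescBind::ops_: a plain loop.
  // A copying insert(first, last) over unique_ptrs would not compile.
  std::deque<std::unique_ptr<FakeOpDesc>> block_ops;
  for (std::unique_ptr<FakeOpDesc>& p : backward_descs) {
    block_ops.push_back(std::move(p));
  }

  for (const auto& op : block_ops) std::cout << op->type << "\n";
  return 0;
}
```

An equivalent one-liner for the loop variant would be `block_ops.insert(block_ops.end(), std::make_move_iterator(backward_descs.begin()), std::make_move_iterator(backward_descs.end()));`; the explicit loop in the patch simply spells the moves out.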
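A second note, on the `prefix` computation that feeds the `fill_zeros_like` ops: `sizeof(kGradVarSuffix)` counts the string's trailing `'\0'`, which the `+ 1` compensates for. A worked sketch, with the suffix values assumed here purely for illustration (the framework defines the real constants):

```cpp
#include <iostream>
#include <string>

// Assumed values, for illustration only.
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr char kZeroVarSuffix[] = "@ZERO";

int main() {
  std::string in_name = "x@GRAD";
  // sizeof(kGradVarSuffix) == 6 (five characters plus '\0'), so the
  // expression keeps in_name.size() - 5 characters, i.e. the "x".
  std::string prefix = in_name.substr(
      0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
  std::string new_name = prefix + kZeroVarSuffix;
  std::cout << in_name << " -> " << new_name << "\n";  // x@GRAD -> x@ZERO
  return 0;
}
```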
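Finally, on the `dup_out_ops` map: it records, for every gradient variable, the indices of the backward ops that write it. When more than one op writes the same variable, each write is renamed to a unique name and a `sum` op is queued (after the last writer, `dup_op.back()`) to combine them. A simplified sketch of that bookkeeping over plain strings; the container names mirror the patch, but the `@RENAME@` naming scheme is an assumption for illustration:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical backward ops, reduced to the one output each writes.
  // x@GRAD is written twice, so the two results must be summed.
  std::vector<std::string> grad_outputs = {"x@GRAD", "y@GRAD", "x@GRAD"};

  // Same bookkeeping as dup_out_ops: variable -> indices of its writers.
  std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
  for (size_t idx = 0; idx < grad_outputs.size(); ++idx) {
    dup_out_ops[grad_outputs[idx]].push_back(idx);
  }

  for (const auto& entry : dup_out_ops) {
    const std::vector<size_t>& dup_op = entry.second;
    if (dup_op.size() == 1) continue;  // a single writer needs no sum op
    // Rename every duplicated write to a unique variable ...
    std::vector<std::string> sum_op_inputs;
    for (size_t i = 0; i < dup_op.size(); ++i) {
      std::string new_name = entry.first + "@RENAME@" + std::to_string(i);
      grad_outputs[dup_op[i]] = new_name;
      sum_op_inputs.push_back(new_name);
    }
    // ... then a "sum" op combining them back into entry.first would be
    // inserted right after the last writer.
    std::cout << "sum over " << sum_op_inputs.size() << " renamed vars -> "
              << entry.first << "\n";
  }
  return 0;
}
```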