From 4f0c071e4909ff041f3a86c3a40c482becf50845 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 28 Aug 2017 22:18:11 +0800 Subject: [PATCH] refine backward --- paddle/framework/backward.cc | 5 ++++- paddle/operators/net_op.cc | 9 ++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index bfda18724cc..6b4c612cd8d 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -124,6 +124,9 @@ static std::unique_ptr BackwardRecursive( std::list insert_position; for (auto& dup_output_op : dup_output_ops) { const std::string& name = dup_output_op.first; + // duplicate @Empty@ don't need to be added + if (name == kEmptyVarName) continue; + auto& dup_op = dup_output_op.second; // no duplicate output if (dup_op.size() == 1) continue; @@ -209,7 +212,7 @@ std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_names; - no_grad_names.reserve(no_grad_vars.size()); + no_grad_names.reserve(no_grad_vars.size() + 1); no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix); diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 44d925f0b0c..78b5e276784 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -31,10 +31,13 @@ void NetOp::CompleteAddOp(bool calc) { for (auto& op : ops_) { for (auto& ipt : op->Inputs()) { for (auto& var_name : ipt.second) { - if (!Contains(output_set, var_name)) { // Not other op's output - input_set.insert(var_name); - } else { + // If input variable has been in output set, then it will be + // added into intermediate_outputs_. Otherwise, it will be + // added into input set. + if (Contains(output_set, var_name)) { intermediate_outputs_.insert(var_name); + } else { + input_set.insert(var_name); } } } -- GitLab