提交 75e16bd3 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #3725 from QiJune/refine_backward

refine backward
...@@ -124,6 +124,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -124,6 +124,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
std::list<Pos> insert_position; std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) { for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first; const std::string& name = dup_output_op.first;
// duplicate @Empty@ don't need to be added
if (name == kEmptyVarName) continue;
auto& dup_op = dup_output_op.second; auto& dup_op = dup_output_op.second;
// no duplicate output // no duplicate output
if (dup_op.size() == 1) continue; if (dup_op.size() == 1) continue;
...@@ -209,7 +212,7 @@ std::unique_ptr<OperatorBase> Backward( ...@@ -209,7 +212,7 @@ std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names; std::unordered_set<std::string> no_grad_names;
no_grad_names.reserve(no_grad_vars.size()); no_grad_names.reserve(no_grad_vars.size() + 1);
no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix); no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
......
...@@ -31,10 +31,13 @@ void NetOp::CompleteAddOp(bool calc) { ...@@ -31,10 +31,13 @@ void NetOp::CompleteAddOp(bool calc) {
for (auto& op : ops_) { for (auto& op : ops_) {
for (auto& ipt : op->Inputs()) { for (auto& ipt : op->Inputs()) {
for (auto& var_name : ipt.second) { for (auto& var_name : ipt.second) {
if (!Contains(output_set, var_name)) { // Not other op's output // If input variable has been in output set, then it will be
input_set.insert(var_name); // added into intermediate_outputs_. Otherwise, it will be
} else { // added into input set.
if (Contains(output_set, var_name)) {
intermediate_outputs_.insert(var_name); intermediate_outputs_.insert(var_name);
} else {
input_set.insert(var_name);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册