提交 186fb0c1 编写于 作者: Y Yu Yang

Remove input_format in backward.cc

上级 ffbb0be2
...@@ -127,11 +127,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -127,11 +127,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
net->ops_[op_offset]->Rename(name, dup_outputs.back()); net->ops_[op_offset]->Rename(name, dup_outputs.back());
} }
insert_position.push_back( insert_position.push_back(
{dup_op.back(), {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
OpRegistry::CreateOp( {{"Out", {name}}}, {})});
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
} }
insert_position.sort( insert_position.sort(
...@@ -140,7 +137,6 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -140,7 +137,6 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (auto& pos : insert_position) { for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second); net->InsertOp(pos.first + 1, pos.second);
} }
} else { } else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp); std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
...@@ -176,7 +172,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive( ...@@ -176,7 +172,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
net->type_ = "@GENERATED_BACKWARD@"; net->type_ = "@GENERATED_BACKWARD@";
net->CompleteAddOp(); net->CompleteAddOp();
return net; return net;
} } // namespace framework
// See header for comments // See header for comments
std::shared_ptr<OperatorBase> Backward( std::shared_ptr<OperatorBase> Backward(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册