提交 4b4f7457 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3467 from reyoung/feature/remove_input_format_in_backward

Remove input_format in backward.cc
......@@ -127,11 +127,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
net->ops_[op_offset]->Rename(name, dup_outputs.back());
}
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
{dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
{{"Out", {name}}}, {})});
}
insert_position.sort(
......@@ -140,7 +137,6 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second);
}
} else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
......@@ -176,7 +172,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
net->type_ = "@GENERATED_BACKWARD@";
net->CompleteAddOp();
return net;
}
} // namespace framework
// See header for comments
std::shared_ptr<OperatorBase> Backward(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册