提交 7454ec04 编写于 作者: L Liu Yiqun

Simplify backward when inserting a sum operator to accumulate all duplicated variables.

上级 23407e7a
......@@ -172,30 +172,14 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back());
}
// collect all the offset to append `add` op for each alias
//
// one variable is shared between multiple operators.
// insert add operator one by one, then add it to output
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
++output_idx) {
auto insert_add_x = dup_outputs[output_idx];
auto insert_add_y = dup_outputs[output_idx + 1];
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
// first add op inserted
if (output_idx == dup_outputs.size() - 2) {
insert_add_out = name;
}
if (output_idx != 0) {
insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
}
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
{{"Out", {insert_add_out}}}, {})});
}
// collect all the offset for each alias,
// insert a sum operator to add all aliases to output
insert_position.push_back(
{dup_op.back(), OpRegistry::CreateOp("sum", {{"X", dup_outputs}},
{{"Out", {name}}}, {})});
}
// make sure the inserted `add` ops follow the BFS order.
// make sure the inserted `sum` ops follow the BFS order.
insert_position.sort(
[](const Pos& l, const Pos& r) { return l.first > r.first; });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册