From 3e99b166ba147b8d954332a9be882bee25ca6591 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sun, 1 Oct 2017 08:29:09 +0000 Subject: [PATCH] add generic add operator --- paddle/framework/backward.cc | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0ec18de5b8a..c625c0caf7d 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -141,9 +141,35 @@ static std::unique_ptr BackwardRecursive( net->ops_[op_offset]->Rename(name, dup_outputs.back()); } // collect all the offset to append `add` op for each alias - insert_position.push_back( - {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}}, - {{"Out", {name}}}, {})}); + // + // one variable is shared between multiple operators. + // insert add operator one by one, then add it to output + if (dup_outputs.size() == 2) { + insert_position.push_back( + {dup_op.back(), + OpRegistry::CreateOp( + "add", {{"X", {dup_outputs[0]}}, {"Y", {dup_outputs[1]}}}, + {{"Out", {name}}}, {})}); + } else { + for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1; + ++output_idx) { + auto insert_add_x = dup_outputs[output_idx]; + auto insert_add_y = dup_outputs[output_idx]; + auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx); + // first add op inserted + if (output_idx == dup_outputs.size() - 1) { + insert_add_out = name; + } + if (output_idx != 0) { + insert_add_y = name + "@SHARED@" + std::to_string(output_idx); + } + insert_position.push_back( + {dup_op.back(), + OpRegistry::CreateOp( + "add", {{"X", {insert_add_x}}, {"Y", {insert_add_y}}}, + {{"Out", {insert_add_out}}}, {})}); + } + } } // make sure the inserted `add` ops follow the BFS order. -- GitLab