diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 8538ad9f0a95a19cba8bcf771166181793492eac..716e78f342ea5d7a5cb331e68fc7e7c6f7e4cfdf 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" +#include #include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" @@ -71,7 +72,7 @@ static std::shared_ptr BackwardImpl( return EmptyOp(); } - // auto* net = new NetOp(); + auto* net = new NetOp(); if (forwardOp.IsNetOp()) { //! TODO(dzh) @@ -93,29 +94,32 @@ static std::shared_ptr BackwardImpl( } // unique the duplicate name auto uid = uniq_id++; - std::unordered_map insert_postion; + // TODO(dzh): more comment + typedef std::pair> Pos; + std::list insert_postion; for (auto& dup_output_op : dup_output_ops) { - std::string& name = dup_output_op.first; + const std::string& name = dup_output_op.first; auto& dup_op = dup_output_op.second; if (dup_op.size() == 1) continue; std::vector dup_outputs; for (size_t i = 0; i < dup_op.size(); ++i) { auto op_offset = dup_op[i]; - net->ops_[op_offset].Rename( - name, - name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i)); + dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + + std::to_string(i)); + net->ops_[op_offset]->Rename(name, dup_outputs.back()); } - insert_postion[op_offset] = - OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {}); - net->AddOp("Add"); - net->AddOp(); - // process shared variable - // while(dup_op.size()) { - // - // AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, - // {dup_op->inputs_}, {})); - //} + insert_postion.push_back( + {dup_op.back(), + OpRegistry::CreateOp( + "Add", {dup_outputs}, {name}, + {{"input_format", + std::vector{0, (int)dup_outputs.size()}}})}); + } + insert_postion.sort( + [](const Pos& l, const Pos& r) { return l.first > r.first; }); + for (auto& pos : insert_postion) { + net->InsertOp(pos.first, pos.second); } } else { diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 6f86b62b48d1d00d5c565241d5e78c57f1adf5db..0666bcc14cc3be9e1e9260e4c31cf4eec3a92e03 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -215,7 +215,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), all_output.end()); - ASSERT_EQ(2, bwd_net->ops_.size()); + ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); ASSERT_EQ(3, first_fc_grad->ops_.size()); @@ -333,4 +333,4 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ASSERT_EQ(grad_fc.Input("W"), "w3"); ASSERT_EQ(grad_fc.Input("b"), "b3"); ASSERT_EQ(grad_fc.Input("Out"), "out3"); -} \ No newline at end of file +} diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 403d96a22de87ce551ebd330a10434f9be7ff609..2cd378c6b21303d1a24206ba3010b0d035aaa766 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -74,27 +74,5 @@ std::string NetOp::DebugString() const { bool NetOp::IsNetOp() const { return true; } -void NetOp::Rename(const std::unordered_map< - std::string, std::vector>& dup_output_ops, - size_t& uniq_id) { - for (auto& op : ops_) { - if (op->isNetOp()) { - op->Rename(dup_output_ops, uniq_id); - } - for (size_t i = 0; i < op->outputs_.size(); ++i) { - std::vector dup_outputs; - if (op->outputs_[i] ==) { - op->outputs_[i] += std::to_string(uniq_id++); - dup_outputs.push_back(op->outputs_[i]); - } - // add duplicate output together. replace with AddOp - if (dup_outputs.size() >= 2) { - AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs}, {op->inputs_}, - {})); - } - } - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index bc55c8ee05d1b37317cd9c230e1fccaa1613eba4..9c7f0eab73b4171d085241324bf8521cd4064af7 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -49,11 +49,6 @@ class NetOp : public OperatorBase { } } - /** - * @brief rename duplicated output gradient name in Net - */ - bool Rename(size_t& uniq_id); - /** * @brief Run the network. *