From d0b25ac9b87225a31a2d9468ffb86a0ffe51b4c7 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Fri, 28 Jul 2017 13:11:54 +0800
Subject: [PATCH] Fix some unittest error

---
 paddle/framework/backward.cc      | 13 +++++++++----
 paddle/framework/backward_test.cc | 30 ++++++++++++++++++++----------
 paddle/framework/operator.cc      |  4 ++--
 paddle/framework/operator.h       |  1 +
 4 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 2d9efdd5114..52eccfba697 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -72,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     return EmptyOp();
   }
 
-  auto* net = new NetOp();
+  auto net = std::make_shared<NetOp>();
 
   if (forwardOp.IsNetOp()) {
     //! TODO(dzh)
@@ -84,7 +84,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
     // travesal subnet/op
-    for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) {
+    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
+         ++it) {
       auto fwd = *it;
       //      for (auto& fwd : forwardNet.ops_) {
       //        auto bwd = Backward(*fwd, no_grad_names);
       //
@@ -115,7 +116,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
        insert_postion.push_back(
            {dup_op.back(),
             OpRegistry::CreateOp(
-                "Add", {dup_outputs}, {name},
+                "add", {dup_outputs}, {name},
                {{"input_format",
                  std::vector<int>{0, (int)dup_outputs.size()}}})});
      }
@@ -142,11 +143,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
        grad_output = OperatorBase::EMPTY_VAR_NAME();
      }
    }
+
+    if (net->ops_.empty()) {  // Current no aux op is added to network
+      return grad_op;
+    }
    net->AddOp(grad_op);
  }
 
  net->CompleteAddOp();
-  return std::shared_ptr<OperatorBase>(net);
+  return net;
 }
 
 extern std::shared_ptr<OperatorBase> Backward(
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 60fbb486888..63194e78fcf 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -63,14 +63,22 @@ class FcOp : public NetOp {
  public:
   void Init() override {
     AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
-                               {Output("before_act")}, {}));
+                               {Output("mul_result")}, {}));
     auto b_name = Input("b");
+    std::string before_act = "mul_result";
     if (b_name != EMPTY_VAR_NAME()) {
-      AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
-                                 {Output("before_act")}, {}));
+      AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
+                                 {Output("add_result")}, {}));
+      before_act = "add_result";
+    } else {
+      auto out_varname = Output("add_result");
+      if (out_varname != EMPTY_VAR_NAME()) {
+        this->Rename(out_varname, EMPTY_VAR_NAME());
+      }
     }
-    AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
-                               {Output("Out")}, {}));
+
+    AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")},
+                               {}));
     CompleteAddOp(false);
   }
 };
@@ -82,7 +90,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
     AddInput("X", "x");
     AddInput("W", "w");
     AddInput("b", "b");
-    AddOutput("before_act", "before act").SetTemporary();
+    AddOutput("mul_result", "").SetTemporary();
+    AddOutput("add_result", "").SetTemporary();
     AddOutput("Out", "");
     AddComment("");
   }
@@ -153,7 +162,7 @@ TEST(Backward, simple_op_grad) {
 
 TEST(Backward, net_fc_backward_normal) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "w", "b"}, {"out", "tmp_forward"}, {});
+      "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
@@ -176,7 +185,7 @@ TEST(Backward, net_fc_backward_normal) {
 TEST(Backward, net_fc_backward_not_have_b) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
       "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
-      {"out", "tmp_forward"}, {});
+      {"mul_result", "add_result", "tmp"}, {});
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
@@ -196,9 +205,9 @@ TEST(Backward, net_fc_backward_not_have_b) {
 TEST(Backward, net_input_of_network_not_need_grad) {
   f::NetOp net;
   net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
-                                    {"hidden0", "tmp0"}, {}));
+                                    {"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
-                                    {"hidden1", "tmp1"}, {}));
+                                    {"mul_tmp_1", "add_tmp_1", "hidden1"}, {}));
   net.CompleteAddOp();
   auto bwd = Backward(net, {"X"});  // X@GRAD is not need.
   ASSERT_TRUE(bwd->IsNetOp());
@@ -235,6 +244,7 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp*>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
+  LOG(INFO) << bwd_net->DebugString();
   ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
 }
 
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 3ad9dc2d7bf..646269074cc 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -52,7 +52,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
   PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
   auto input_format = GetAttr<std::vector<int>>("input_format");
   auto offset = in_out_idxs_->at(name);
-  PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(),
+  PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= (int)inputs_.size(),
                  "Input Out Of Range");
 
   return std::vector<std::string>{
@@ -78,7 +78,7 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
   PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
   auto output_format = GetAttr<std::vector<int>>("output_format");
   auto offset = in_out_idxs_->at(name);
-  PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(),
+  PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(),
                  "Output Out of Range");
   return std::vector<std::string>{
       outputs_.begin() + output_format.at(offset),
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index eecf2f8302d..358ab841d6c 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -101,6 +101,7 @@ class OperatorBase {
   //! Get a input with argument's name described in `op_proto`
   const std::string& Input(const std::string& name) const;
 
+  //! Get a input which has multiple variables.
   //! TODO add a vector_view to prevent memory copy.
   std::vector<std::string> Inputs(const std::string& name) const;
 
--
GitLab