From 84198f75483aa9b7718c71d3bafa3372f73aef5a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 27 Jul 2017 16:06:43 +0800 Subject: [PATCH] Add unittest --- paddle/framework/backward_test.cc | 58 +++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index e920af3d1..81a55a42b 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -108,6 +108,16 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker { AddComment(""); } }; + +class AddOpMaker : public OpProtoAndCheckerMaker { + public: + AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "x").SetMultiple(); + AddOutput("Y", "y"); + AddComment(""); + } +}; } // namespace framework } // namespace paddle @@ -123,12 +133,14 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); +REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); +REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); -TEST(Backward, simple_grad) { +TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1, gop->inputs_.size()); + ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); @@ -139,7 +151,7 @@ TEST(Backward, simple_grad) { // LOG(INFO) << gop->Output("X" + "@GRAD"); } -TEST(Backward, fc_backward_normal) { +TEST(Backward, net_fc_backward_normal) { std::shared_ptr fwd = f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {}); ASSERT_NE(fwd, nullptr); @@ -161,7 +173,7 @@ TEST(Backward, fc_backward_normal) { ASSERT_EQ("mul_grad", d_mul.type_); } -TEST(Backward, fc_backward_not_have_b) { +TEST(Backward, net_fc_backward_not_have_b) { std::shared_ptr fwd = f::OpRegistry::CreateOp( "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {}); ASSERT_NE(fwd, nullptr); @@ -180,12 +192,12 @@ TEST(Backward, fc_backward_not_have_b) { ASSERT_EQ("mul_grad", d_mul.type_); } -TEST(Backward, input_layer_not_need_grad) { +TEST(Backward, net_input_of_network_not_need_grad) { f::NetOp net; net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {})); net.AddOp( f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {})); - + net.CompleteAddOp(); auto bwd = Backward(net, {"X"}); // X@GRAD is not need. ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); @@ -198,16 +210,40 @@ TEST(Backward, input_layer_not_need_grad) { ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()), all_output.end()); } + + // Not Generated X + ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + all_output.end()); + + ASSERT_EQ(2, bwd_net->ops_.size()); + ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); + auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); + ASSERT_EQ(3, first_fc_grad->ops_.size()); + ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), + first_fc_grad[2].Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, net_shared_weight) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); + net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); + net.CompleteAddOp(); + + auto bwd = f::Backward(net, {}); + ASSERT_TRUE(bwd->IsNetOp()); + auto bwd_net = static_cast(bwd.get()); + ASSERT_EQ(3UL, bwd_net->ops_.size()); + ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_); } -TEST(Backward, not_for_network) { +TEST(Backward, op_register_grad_not_for_network) { auto fwd = f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"}, {{"temporary_index", std::vector{1}}}); ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); } -TEST(Backward, all_input_are_not_need) { +TEST(Backward, op_all_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto backward = f::Backward(*fwd, {"X", "b"}); ASSERT_TRUE(backward->IsNetOp()); @@ -215,7 +251,7 @@ TEST(Backward, all_input_are_not_need) { ASSERT_TRUE(net->ops_.empty()); } -TEST(Backward, all_output_are_not_need) { +TEST(Backward, op_all_output_are_not_need) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto backward = f::Backward(*fwd, {"Out"}); ASSERT_TRUE(backward->IsNetOp()); @@ -223,7 +259,7 @@ TEST(Backward, all_output_are_not_need) { ASSERT_TRUE(net->ops_.empty()); } -TEST(Backward, part_of_output_are_not_need) { +TEST(Backward, op_part_of_output_are_not_need) { auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); auto backward = f::Backward(*fwd, {"Z"}); ASSERT_TRUE(backward->IsNetOp()); @@ -248,7 +284,7 @@ TEST(Backward, part_of_output_are_not_need) { d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); } -TEST(Backward, part_of_input_are_not_need) { +TEST(Backward, op_part_of_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto backward = f::Backward(*fwd, {"a"}); ASSERT_TRUE(backward->IsNetOp()); -- GitLab