diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 609dc661f2abcf9e892209501550292e8ed4fdb0..6ab81b5589307efe83f725537195671794cd6dcb 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -128,6 +128,7 @@ TEST(Backward, simple_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); + ASSERT_EQ(1, gop->inputs_.size()); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); @@ -138,6 +139,67 @@ TEST(Backward, simple_grad) { // LOG(INFO) << gop->Output("X" + "@GRAD"); } +TEST(Backward, fc_backward_normal) { + std::shared_ptr fwd = + f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {}); + ASSERT_NE(fwd, nullptr); + std::shared_ptr gop = f::Backward(*fwd, {}); + ASSERT_TRUE(gop->IsNetOp()); + auto net = static_cast(gop.get()); + + ASSERT_NO_THROW(net->DebugString()); + + ASSERT_EQ(3UL, net->ops_.size()); + + f::OperatorBase &d_sigmoid = *net->ops_[0]; + ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + + f::OperatorBase &d_add = *net->ops_[1]; + ASSERT_EQ("rowwise_add_grad", d_add.type_); + + f::OperatorBase &d_mul = *net->ops_[2]; + ASSERT_EQ("mul_grad", d_mul.type_); +} + +TEST(Backward, fc_backward_not_have_b) { + std::shared_ptr fwd = f::OpRegistry::CreateOp( + "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {}); + ASSERT_NE(fwd, nullptr); + std::shared_ptr gop = f::Backward(*fwd, {}); + ASSERT_TRUE(gop->IsNetOp()); + auto net = static_cast(gop.get()); + + ASSERT_NO_THROW(net->DebugString()); + + ASSERT_EQ(2UL, net->ops_.size()); + + f::OperatorBase &d_sigmoid = *net->ops_[0]; + ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + + f::OperatorBase &d_mul = *net->ops_[1]; + ASSERT_EQ("mul_grad", d_mul.type_); +} + +TEST(Backward, input_layer_not_need_grad) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {})); + net.AddOp( + f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {})); + + auto bwd = Backward(net, {"X"}); // X@GRAD is not need. + ASSERT_TRUE(bwd->IsNetOp()); + auto bwd_net = static_cast(bwd.get()); + + std::unordered_set all_output = std::unordered_set( + bwd_net->outputs_.begin(), bwd_net->outputs_.end()); + all_output.erase(f::OperatorBase::EMPTY_VAR_NAME()); + + for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { + ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()), + all_output.end()); + } +} + TEST(Backward, not_for_network) { auto fwd = f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"}, @@ -166,7 +228,7 @@ TEST(Backward, part_of_output_are_not_need) { auto backward = f::Backward(*fwd, {"Z"}); ASSERT_TRUE(backward->IsNetOp()); auto net = static_cast(backward.get()); - ASSERT_EQ(net->ops_.size(), 2); + ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; ASSERT_EQ("fill_zeros_like", fill_zero.type_);