diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index e920af3d1ac511f360ba32630c7812a939c27428..538522bf441ab71be6e76505e282a86b3878f267 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -128,7 +128,7 @@ TEST(Backward, simple_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1, gop->inputs_.size()); + ASSERT_EQ(1UL, gop->inputs_.size()); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); @@ -265,4 +265,36 @@ TEST(Backward, part_of_input_are_not_need) { "b" + f::OperatorBase::GRAD_VAR_SUFFIX()); ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), "out" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("A"), "a"); + ASSERT_EQ(grad_mul.Input("B"), "b"); + ASSERT_EQ(grad_mul.Input("Out"), "out"); +} + +TEST(Backward, intermediate_variable_not_need_in_linear_net) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {})); + net.CompleteAddOp(false); + auto backward = f::Backward(net, {"out2"}); + ASSERT_TRUE(backward->IsNetOp()); + auto bwd_net = static_cast(backward.get()); + ASSERT_EQ(bwd_net->ops_.size(), 1UL); + + auto &grad_fc = *bwd_net->ops_[0]; + ASSERT_EQ(grad_fc.type_, "fc_grad"); + ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL); + ASSERT_EQ(grad_fc.outputs_.size(), 3UL); + ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + ASSERT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("X"), "out2"); + ASSERT_EQ(grad_fc.Input("W"), "w3"); + ASSERT_EQ(grad_fc.Input("b"), "b3"); + ASSERT_EQ(grad_fc.Input("Out"), "out3"); } \ No newline at end of file