diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 81a55a42b483d34291ccb32e7284820cfb3b9974..6f86b62b48d1d00d5c565241d5e78c57f1adf5db 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -301,4 +301,36 @@ TEST(Backward, op_part_of_input_are_not_need) { "b" + f::OperatorBase::GRAD_VAR_SUFFIX()); ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), "out" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("A"), "a"); + ASSERT_EQ(grad_mul.Input("B"), "b"); + ASSERT_EQ(grad_mul.Input("Out"), "out"); +} + +TEST(Backward, linear_net_intermediate_variable_has_no_grad) { + f::NetOp net; + net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {})); + net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {})); + net.CompleteAddOp(false); + auto backward = f::Backward(net, {"out2"}); + ASSERT_TRUE(backward->IsNetOp()); + auto bwd_net = static_cast(backward.get()); + ASSERT_EQ(bwd_net->ops_.size(), 1UL); + + auto &grad_fc = *bwd_net->ops_[0]; + ASSERT_EQ(grad_fc.type_, "fc_grad"); + ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL); + ASSERT_EQ(grad_fc.outputs_.size(), 3UL); + ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + ASSERT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("X"), "out2"); + ASSERT_EQ(grad_fc.Input("W"), "w3"); + ASSERT_EQ(grad_fc.Input("b"), "b3"); + ASSERT_EQ(grad_fc.Input("Out"), "out3"); } \ No newline at end of file