diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 9830e4c09254e84bfc45234af14efcc4da513bb7..36d6cbb5eec8a22d3529a66dd6d2732ee156b026 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -154,7 +154,6 @@ REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); REGISTER_OP(fc, f::FcOp, f::FcOpMaker); -REGISTER_GRADIENT_OP(fc, fc_grad, f::EmptyOp); REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); @@ -326,7 +325,6 @@ TEST(Backward, op_part_of_output_are_not_need) { d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); } -/* TEST(Backward, op_part_of_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto backward = f::Backward(*fwd, {"a"}); @@ -344,7 +342,6 @@ TEST(Backward, op_part_of_input_are_not_need) { ASSERT_EQ(grad_mul.Input("B"), "b"); ASSERT_EQ(grad_mul.Input("Out"), "out"); } -*/ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { f::NetOp net; @@ -359,13 +356,19 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ASSERT_TRUE(backward->IsNetOp()); auto bwd_net = static_cast(backward.get()); ASSERT_EQ(bwd_net->ops_.size(), 3UL); - EXPECT_EQ(bwd_net->ops_[0]->type_, "fc_grad"); + EXPECT_EQ(bwd_net->ops_[0]->type_, ""); EXPECT_EQ(bwd_net->ops_[1]->type_, ""); EXPECT_EQ(bwd_net->ops_[2]->type_, ""); auto &grad_fc = *bwd_net->ops_[0]; EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL); EXPECT_EQ(grad_fc.outputs_.size(), 3UL); + + EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); + EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), f::OperatorBase::EMPTY_VAR_NAME()); EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),