diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index dd0d2be6685ce794372291caac3c541a6a5b9be2..878d3010deda5ab24468e86b4f06580c649e147a 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -166,7 +166,7 @@ TEST(Backward, part_of_output_are_not_need) {
   auto backward = f::Backward(*fwd, {"Z"});
   ASSERT_TRUE(backward->IsNetOp());
   auto net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(net->ops_.size(), 2);
+  ASSERT_EQ(net->ops_.size(), 2UL);
 
   auto &fill_zero = *net->ops_[0];
   ASSERT_EQ("fill_zeros_like", fill_zero.type_);
@@ -184,4 +184,23 @@ TEST(Backward, part_of_output_are_not_need) {
             d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
             d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+}
+
+TEST(Backward, part_of_input_are_not_need) {
+  auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
+  auto backward = f::Backward(*fwd, {"a"});
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 1UL);
+
+  auto &grad_mul = *net->ops_[0];
+  ASSERT_EQ(grad_mul.type_, "mul_grad");
+  ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
+  ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
+  ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            f::OperatorBase::EMPTY_VAR_NAME());
+  ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "b" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "out" + f::OperatorBase::GRAD_VAR_SUFFIX());
 }
\ No newline at end of file