diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index dd0d2be6685ce794372291caac3c541a6a5b9be2..609dc661f2abcf9e892209501550292e8ed4fdb0 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -184,4 +184,17 @@ TEST(Backward, part_of_output_are_not_need) { d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX())); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, part_of_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"X"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_EQ(1UL, net->ops_.size()); + + auto &d_add = *net->ops_[0]; + ASSERT_EQ("rowwise_add_grad", d_add.type_); + ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), + d_add.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); } \ No newline at end of file