diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index cb1d402526e0b0008263c8f19e4d175c6e1502f6..a481cb1b2a7ccfd41e5e07265b4809c53649e634 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -13,6 +13,7 @@
    limitations under the License. */
 
 #include "paddle/framework/backward.h"
+
 #include <gtest/gtest.h>
 #include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
@@ -161,12 +162,23 @@ TEST(Backward, simple_op_grad) {
 }
 
 TEST(Backward, simple_op_not_need_grad) {
-  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"x", "b"}, {"out"}, {});
+  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
-  auto gop = f::Backward(*fwd, {"x"});
-  LOG(INFO) << gop->DebugString();
-  ASSERT_NE(gop->outputs_.find("x" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+  auto gop = f::Backward(*fwd, {"X"});
+  LOG(INFO) << "full " << gop->DebugString();
+  ASSERT_NE(std::find(gop->outputs_.begin(), gop->outputs_.end(),
+                      "X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
             gop->outputs_.end());
+  auto no_input_gop = f::Backward(*fwd, {"X", "b"});
+  LOG(INFO) << "no input gop " << no_input_gop->DebugString();
+  ASSERT_NE(no_input_gop, nullptr);
+  ASSERT_EQ(std::vector<std::string>{}, no_input_gop->outputs_);
+  ASSERT_EQ(
+      std::vector<std::string>{"Out" + f::OperatorBase::GRAD_VAR_SUFFIX()},
+      no_input_gop->inputs_);
+  // auto no_output_gop = f::Backward(*fwd, {"Out"});
+  // ASSERT_EQ(std::vector<std::string>{"X" +
+  // f::OperatorBase::GRAD_VAR_SUFFIX(), "b"})
 }
 
 TEST(Backward, net_fc_backward_normal) {