diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index ffdadd709f09d9191a9358c8f930215c382f56ea..9830e4c09254e84bfc45234af14efcc4da513bb7 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -60,6 +60,16 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
+class NoGradOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "X input");
+    AddOutput("Y", "Y output");
+    AddComment("NoGradOp, same input output. no Grad");
+  }
+};
+
 class FcOp : public NetOp {
  public:
   void Init() override {
@@ -139,6 +149,7 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
 REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
 REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
 REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
+REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
 REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
 REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
 REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
@@ -266,9 +277,11 @@ TEST(Backward, net_shared_weight) {
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd =
-      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
-                              {{"temporary_index", std::vector<int>{1}}});
+  //  auto fwd =
+  //      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
+  //                              {{"temporary_index", std::vector<int>{1}}});
+
+  auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
 
@@ -317,11 +330,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_FALSE(backward->IsNetOp());
-  auto net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(net->ops_.size(), 1UL);
-
-  auto &grad_mul = *net->ops_[0];
+  auto &grad_mul = *backward;
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);