From 494b3bda7d784315433b85826c9cbd18cac5723a Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 3 Oct 2017 10:28:57 -0700 Subject: [PATCH] fix backward test case --- paddle/framework/backward.cc | 2 +- paddle/framework/backward_test.cc | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 35759f8e78..2c13ddd8d0 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -159,7 +159,7 @@ static std::unique_ptr BackwardRecursive( insert_position.push_back( {dup_op.back(), OpRegistry::CreateOp( - "add", {{"X", {insert_add_x}}, {"X", {insert_add_y}}}, + "sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}}, {{"Out", {insert_add_out}}}, {})}); } } diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 6932f5b989..a36e7bde8c 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker { } }; -class AddOpMaker : public OpProtoAndCheckerMaker { +class SumOpMaker : public framework::OpProtoAndCheckerMaker { public: - AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "x").AsDuplicable(); - AddOutput("Out", "out"); + AddInput("X", "the input tensors of sum operator.") + .AsDuplicable() + .NotInGradient(); + AddOutput("Out", "the output tensor of sum operator.").NotInGradient(); AddComment(""); } }; + } // namespace framework } // namespace paddle @@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); -REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP); +REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, f::NOP); @@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) { ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); - ASSERT_EQ("add", bwd_net->ops_[2]->Type()); + ASSERT_EQ("sum", bwd_net->ops_[2]->Type()); } TEST(Backward, op_register_grad_not_for_network) { -- GitLab