Commit 494b3bda authored by dongzhihong

fix backward test case

Parent: 3dc4f46f
@@ -159,7 +159,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       insert_position.push_back(
           {dup_op.back(),
            OpRegistry::CreateOp(
-               "add", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
+               "sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
                {{"Out", {insert_add_out}}}, {})});
     }
   }
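Note on this hunk: when one variable is read by several forward operators, BackwardRecursive ends up with several partial gradients for it; the duplicated gradient outputs are renamed, and a single accumulation op is inserted at insert_position to merge them back. This commit changes that inserted op's type from "add" to "sum". A minimal sketch of the merge op being built here, with hypothetical renamed gradient variables (both names travel under the one duplicable input "X"):

    // Hedged illustration; "w@GRAD@RENAME@0/1" are made-up names for the
    // renamed duplicates of a shared gradient "w@GRAD".
    auto merge_op = OpRegistry::CreateOp(
        "sum", {{"X", {"w@GRAD@RENAME@0", "w@GRAD@RENAME@1"}}},
        {{"Out", {"w@GRAD"}}}, {});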
@@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
-class AddOpMaker : public OpProtoAndCheckerMaker {
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "x").AsDuplicable();
-    AddOutput("Out", "out");
+    AddInput("X", "the input tensors of sum operator.")
+        .AsDuplicable()
+        .NotInGradient();
+    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
     AddComment("");
   }
 };
 
 }  // namespace framework
 }  // namespace paddle
@@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
 REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
-REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
+REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
@@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  ASSERT_EQ("add", bwd_net->ops_[2]->Type());
+  ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
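Why the expected type changes: net_shared_weight builds a forward net in which one weight is consumed by two operators, so backward produces two partial gradients for it and appends one accumulation op, now of type "sum", as the third op of the backward net. A rough sketch of such a shared-weight net, assuming NetOp's AddOp method and illustrative op and variable names:

    // Hedged sketch, not the literal test body.
    ops::NetOp net;
    net.AddOp(f::OpRegistry::CreateOp(
        "mul", {{"X", {"x"}}, {"Y", {"w"}}}, {{"Out", {"out"}}}, {}));
    net.AddOp(f::OpRegistry::CreateOp(
        "mul", {{"X", {"out"}}, {"Y", {"w"}}}, {{"Out", {"final"}}}, {}));
    // "w" gets a gradient from both mul ops, so the backward net ends with
    // the inserted "sum" op that merges the two partial gradients.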