提交 494b3bda 编写于 作者: D dongzhihong

fix backward test case

上级 3dc4f46f
......@@ -159,7 +159,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"add", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
"sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
{{"Out", {insert_add_out}}}, {})});
}
}
......
......@@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").AsDuplicable();
AddOutput("Out", "out");
AddInput("X", "the input tensors of sum operator.")
.AsDuplicable()
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment("");
}
};
} // namespace framework
} // namespace paddle
......@@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
f::NOP);
......@@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->Type());
ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
}
TEST(Backward, op_register_grad_not_for_network) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册