提交 39cd39e0 编写于 作者: F fengjiayi

Update test

上级 dc06eaa0
...@@ -154,7 +154,6 @@ REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); ...@@ -154,7 +154,6 @@ REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_GRADIENT_OP(fc, fc_grad, f::EmptyOp);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
...@@ -326,7 +325,6 @@ TEST(Backward, op_part_of_output_are_not_need) { ...@@ -326,7 +325,6 @@ TEST(Backward, op_part_of_output_are_not_need) {
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
} }
/*
TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"}); auto backward = f::Backward(*fwd, {"a"});
...@@ -344,7 +342,6 @@ TEST(Backward, op_part_of_input_are_not_need) { ...@@ -344,7 +342,6 @@ TEST(Backward, op_part_of_input_are_not_need) {
ASSERT_EQ(grad_mul.Input("B"), "b"); ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out"); ASSERT_EQ(grad_mul.Input("Out"), "out");
} }
*/
TEST(Backward, linear_net_intermediate_variable_has_no_grad) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
f::NetOp net; f::NetOp net;
...@@ -359,13 +356,19 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -359,13 +356,19 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get()); auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL); ASSERT_EQ(bwd_net->ops_.size(), 3UL);
EXPECT_EQ(bwd_net->ops_[0]->type_, "fc_grad"); EXPECT_EQ(bwd_net->ops_[0]->type_, "");
EXPECT_EQ(bwd_net->ops_[1]->type_, ""); EXPECT_EQ(bwd_net->ops_[1]->type_, "");
EXPECT_EQ(bwd_net->ops_[2]->type_, ""); EXPECT_EQ(bwd_net->ops_[2]->type_, "");
auto &grad_fc = *bwd_net->ops_[0]; auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL); EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
EXPECT_EQ(grad_fc.outputs_.size(), 3UL); EXPECT_EQ(grad_fc.outputs_.size(), 3UL);
EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME()); f::OperatorBase::EMPTY_VAR_NAME());
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册