Commit 42e2fa57 authored by Yu Yang

Fix unittest

Parent be528683
@@ -46,6 +46,7 @@ static std::vector<size_t> InSetIdx(
 static std::shared_ptr<OperatorBase> EmptyOp() {
   auto net_op = std::make_shared<NetOp>();
+  net_op->type_ = "@EMPTY_OP@";
   net_op->CompleteAddOp();
   return net_op;
 }
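
This hunk, together with the `@GENERATED_BACKWARD@` hunk below, gives generated placeholder nets a distinctive `type_` tag instead of an empty string. A minimal standalone sketch of why that helps (assumed class shapes, not the framework's real headers): a named placeholder can be told apart from a real operator in assertions and debug output.

// Minimal sketch, not the real framework classes: a named type_ lets callers
// tell a generated placeholder net apart from a real operator.
#include <iostream>
#include <memory>
#include <string>

struct OperatorBase {
  std::string type_;  // the operator's type name
  virtual ~OperatorBase() = default;
};

struct NetOp : public OperatorBase {};

// Analogous to EmptyOp() above: tag the placeholder at creation time.
std::shared_ptr<OperatorBase> MakeEmptyOp() {
  auto net_op = std::make_shared<NetOp>();
  net_op->type_ = "@EMPTY_OP@";
  return net_op;
}

int main() {
  auto op = MakeEmptyOp();
  // Before this commit type_ stayed "", so tests could only compare against
  // an empty string (see the EXPECT_EQ(..., "") removals further down).
  std::cout << (op->type_ == "@EMPTY_OP@" ? "placeholder" : "real op") << "\n";
  return 0;
}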
@@ -140,7 +141,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
         [](const Pos& l, const Pos& r) { return l.first > r.first; });
     for (auto& pos : insert_position) {
-      net->InsertOp(pos.first, pos.second);
+      net->InsertOp(pos.first + 1, pos.second);
     }
   } else {
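
The `pos.first + 1` is an off-by-one fix: each recorded position is the index of the op that the inserted op must follow, and because `insert_position` is sorted in descending order, each insertion leaves the indices still to be processed unshifted. A standalone sketch of that pattern, with invented op names and a plain std::vector standing in for the net's op list:

// Why descending order plus pos + 1 works: inserting at high indices first
// keeps the lower recorded indices valid, and +1 places the new op *after*
// the op it belongs to.
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> ops = {"mul", "add", "mul_grad", "add_grad"};
  // (index of the op the new op must follow, op to insert) -- names invented.
  std::vector<std::pair<size_t, std::string>> insert_position = {
      {1, "fill_zeros_like"}, {2, "add_two"}};

  // Same comparator as in BackwardImpl: largest index first.
  std::sort(insert_position.begin(), insert_position.end(),
            [](const std::pair<size_t, std::string>& l,
               const std::pair<size_t, std::string>& r) {
              return l.first > r.first;
            });

  for (auto& pos : insert_position) {
    // pos.first + 1: insert right after the op at pos.first.
    ops.insert(ops.begin() + pos.first + 1, pos.second);
  }

  for (auto& op : ops) std::cout << op << "\n";
  // Prints: mul, add, fill_zeros_like, mul_grad, add_two, add_grad
}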
@@ -167,7 +168,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     }
     net->AddOp(grad_op);
   }
+  net->type_ = "@GENERATED_BACKWARD@";
   net->CompleteAddOp();
   return net;
 }
@@ -269,15 +269,14 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
+  ASSERT_EQ("add", bwd_net->ops_[2]->type_);
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
-  // auto fwd =
-  //     f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
-  //                             {{"temporary_index", std::vector<int>{1}}});
-  auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
+  auto fwd = f::OpRegistry::CreateOp(
+      "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
+      {{"temporary_index", std::vector<int>{0, 1}}});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
@@ -350,13 +349,11 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                  {"mul_out3", "tmp_out3", "out3"}, {}));
   net.CompleteAddOp();
   auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
-  LOG(INFO) << backward->DebugString();
   ASSERT_TRUE(backward->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(backward.get());
   ASSERT_EQ(bwd_net->ops_.size(), 3UL);
-  EXPECT_EQ(bwd_net->ops_[0]->type_, "");
-  EXPECT_EQ(bwd_net->ops_[1]->type_, "");
-  EXPECT_EQ(bwd_net->ops_[2]->type_, "");
   auto &grad_fc = *bwd_net->ops_[0];
   EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
   EXPECT_EQ(grad_fc.outputs_.size(), 3UL);
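
Note on the arithmetic: the 3UL + 3UL + 3UL spelled out above presumably mirrors how the fc gradient op's inputs are assembled: the three forward inputs ("X", "W", "b"), the three forward outputs, and the gradients of the three outputs, nine in total.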