提交 42e2fa57 编写于 作者: Y Yu Yang

Fix unittest

上级 be528683
......@@ -46,6 +46,7 @@ static std::vector<size_t> InSetIdx(
static std::shared_ptr<OperatorBase> EmptyOp() {
auto net_op = std::make_shared<NetOp>();
net_op->type_ = "@EMPTY_OP@";
net_op->CompleteAddOp();
return net_op;
}
......@@ -140,7 +141,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
[](const Pos& l, const Pos& r) { return l.first > r.first; });
for (auto& pos : insert_position) {
net->InsertOp(pos.first, pos.second);
net->InsertOp(pos.first + 1, pos.second);
}
} else {
......@@ -167,7 +168,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
}
net->AddOp(grad_op);
}
net->type_ = "@GENERATED_BACKWARD@";
net->CompleteAddOp();
return net;
}
......
......@@ -269,15 +269,14 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
ASSERT_EQ("add", bwd_net->ops_[2]->type_);
}
TEST(Backward, op_register_grad_not_for_network) {
// auto fwd =
// f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
// {{"temporary_index", std::vector<int>{1}}});
auto fwd = f::OpRegistry::CreateOp(
"fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
{{"temporary_index", std::vector<int>{0, 1}}});
auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
......@@ -350,13 +349,11 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
{"mul_out3", "tmp_out3", "out3"}, {}));
net.CompleteAddOp();
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
LOG(INFO) << backward->DebugString();
ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
EXPECT_EQ(bwd_net->ops_[0]->type_, "");
EXPECT_EQ(bwd_net->ops_[1]->type_, "");
EXPECT_EQ(bwd_net->ops_[2]->type_, "");
auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
EXPECT_EQ(grad_fc.outputs_.size(), 3UL);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册