提交 8bf0ca0f 编写于 作者: Y Yu Yang

Fix unittest error

上级 e1d10675
...@@ -152,8 +152,8 @@ TEST(Backward, simple_op_grad) { ...@@ -152,8 +152,8 @@ TEST(Backward, simple_op_grad) {
} }
TEST(Backward, net_fc_backward_normal) { TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd = std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {}); "fc", {"X", "w", "b"}, {"out", "tmp_forward"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp()); ASSERT_TRUE(gop->IsNetOp());
...@@ -175,7 +175,8 @@ TEST(Backward, net_fc_backward_normal) { ...@@ -175,7 +175,8 @@ TEST(Backward, net_fc_backward_normal) {
TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {}); "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
{"out", "tmp_forward"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp()); ASSERT_TRUE(gop->IsNetOp());
...@@ -194,9 +195,10 @@ TEST(Backward, net_fc_backward_not_have_b) { ...@@ -194,9 +195,10 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_input_of_network_not_need_grad) {
f::NetOp net; f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
net.AddOp( {"hidden0", "tmp0"}, {}));
f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
{"hidden1", "tmp1"}, {}));
net.CompleteAddOp(); net.CompleteAddOp();
auto bwd = Backward(net, {"X"}); // X@GRAD is not need. auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
ASSERT_TRUE(bwd->IsNetOp()); ASSERT_TRUE(bwd->IsNetOp());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册