提交 3dd5fd04 编写于 作者: F fengjiayi

Add unitest of Backward.intermediate_variable_not_need_in_linear_net

上级 099bb53b
......@@ -128,7 +128,7 @@ TEST(Backward, simple_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(1, gop->inputs_.size());
ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
......@@ -265,4 +265,36 @@ TEST(Backward, part_of_input_are_not_need) {
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out");
}
TEST(Backward, intermediate_variable_not_need_in_linear_net) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {}));
net.CompleteAddOp(false);
auto backward = f::Backward(net, {"out2"});
ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 1UL);
auto &grad_fc = *bwd_net->ops_[0];
ASSERT_EQ(grad_fc.type_, "fc_grad");
ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL);
ASSERT_EQ(grad_fc.outputs_.size(), 3UL);
ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("X"), "out2");
ASSERT_EQ(grad_fc.Input("W"), "w3");
ASSERT_EQ(grad_fc.Input("b"), "b3");
ASSERT_EQ(grad_fc.Input("Out"), "out3");
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册