提交 71bd439b 编写于 作者: F fengjiayi

Addjust Backward.linear_net_intermediate_variable_has_no_grad

上级 29d50ad9
...@@ -325,14 +325,14 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -325,14 +325,14 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
{"mul_out3", "tmp_out3", "out3"}, {})); {"mul_out3", "tmp_out3", "out3"}, {}));
net.CompleteAddOp(); net.CompleteAddOp();
auto backward = f::Backward(net, {"out2"}); auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get()); auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 1UL); ASSERT_EQ(bwd_net->ops_.size(), 1UL);
auto &grad_fc = *bwd_net->ops_[0]; auto &grad_fc = *bwd_net->ops_[0];
ASSERT_EQ(grad_fc.type_, "fc_grad"); ASSERT_EQ(grad_fc.type_, "fc_grad");
ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL); ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
ASSERT_EQ(grad_fc.outputs_.size(), 3UL); ASSERT_EQ(grad_fc.outputs_.size(), 3UL);
ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME()); f::OperatorBase::EMPTY_VAR_NAME());
...@@ -340,10 +340,17 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -340,10 +340,17 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("add_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"tmp_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("X"), "out2"); ASSERT_EQ(grad_fc.Input("X"), "out2");
ASSERT_EQ(grad_fc.Input("W"), "w3"); ASSERT_EQ(grad_fc.Input("W"), "w3");
ASSERT_EQ(grad_fc.Input("b"), "b3"); ASSERT_EQ(grad_fc.Input("b"), "b3");
ASSERT_EQ(grad_fc.Input("mul_result"), "mul_out3");
ASSERT_EQ(grad_fc.Input("add_result"), "tmp_out3");
ASSERT_EQ(grad_fc.Input("Out"), "out3"); ASSERT_EQ(grad_fc.Input("Out"), "out3");
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册