提交 a2e2cd77 编写于 作者: F fengjiayi

Fix bug of TEST Backwar.linear_net_intermediate_variable_has_no_grad

上级 39cd39e0
...@@ -356,36 +356,40 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -356,36 +356,40 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get()); auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL); ASSERT_EQ(bwd_net->ops_.size(), 3UL);
EXPECT_EQ(bwd_net->ops_[0]->type_, "");
EXPECT_EQ(bwd_net->ops_[1]->type_, "");
EXPECT_EQ(bwd_net->ops_[2]->type_, "");
auto &grad_fc = *bwd_net->ops_[0]; auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL); EXPECT_EQ(grad_fc.inputs_.size(),
EXPECT_EQ(grad_fc.outputs_.size(), 3UL); 3UL /* external input number */
+ 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/
- 1UL /*ignoreGradient varable number*/
+ 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
+ 2UL /* input number of rowwise_add */
+ 1UL /* input number of sigmod */);
std::cout << std::endl;
EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), /*
f::OperatorBase::EMPTY_VAR_NAME()); EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()), EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); "mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("add_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"tmp_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
EXPECT_EQ(grad_fc.Input("X"), "out2"); EXPECT_EQ(grad_fc.Input("X"), "out2");
EXPECT_EQ(grad_fc.Input("W"), "w3"); EXPECT_EQ(grad_fc.Input("W"), "w3");
EXPECT_EQ(grad_fc.Input("b"), "b3");
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3"); EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3"); EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
EXPECT_EQ(grad_fc.Input("Out"), "out3"); EXPECT_EQ(grad_fc.Input("Out"), "out3");
*/
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册