From 71bd439b45f36d4de5e0c06dfc013859d97684e3 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 28 Jul 2017 15:25:07 +0800 Subject: [PATCH] Addjust Backward.linear_net_intermediate_variable_has_no_grad --- paddle/framework/backward_test.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 7185872d0a0..ae85e6201be 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -325,14 +325,14 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"mul_out3", "tmp_out3", "out3"}, {})); net.CompleteAddOp(); - auto backward = f::Backward(net, {"out2"}); + auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); ASSERT_TRUE(backward->IsNetOp()); auto bwd_net = static_cast(backward.get()); ASSERT_EQ(bwd_net->ops_.size(), 1UL); auto &grad_fc = *bwd_net->ops_[0]; ASSERT_EQ(grad_fc.type_, "fc_grad"); - ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL); + ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL); ASSERT_EQ(grad_fc.outputs_.size(), 3UL); ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()), f::OperatorBase::EMPTY_VAR_NAME()); @@ -340,10 +340,17 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { "w3" + f::OperatorBase::GRAD_VAR_SUFFIX()); ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()), "b3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("add_result" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "tmp_out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), "out3" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_fc.Input("X"), "out2"); ASSERT_EQ(grad_fc.Input("W"), "w3"); ASSERT_EQ(grad_fc.Input("b"), "b3"); + ASSERT_EQ(grad_fc.Input("mul_result"), "mul_out3"); + ASSERT_EQ(grad_fc.Input("add_result"), "tmp_out3"); ASSERT_EQ(grad_fc.Input("Out"), "out3"); } -- GitLab