From 04db4183e975ed3b2d07a57984dd5edf4a8adcb0 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 27 Jul 2017 14:26:17 +0800 Subject: [PATCH] Add unitest of Backward.part_of_input_are_not_need --- paddle/framework/backward_test.cc | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index dd0d2be668..878d3010de 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -166,7 +166,7 @@ TEST(Backward, part_of_output_are_not_need) { auto backward = f::Backward(*fwd, {"Z"}); ASSERT_TRUE(backward->IsNetOp()); auto net = static_cast(backward.get()); - ASSERT_EQ(net->ops_.size(), 2); + ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; ASSERT_EQ("fill_zeros_like", fill_zero.type_); @@ -184,4 +184,23 @@ TEST(Backward, part_of_output_are_not_need) { d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX())); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX())); +} + +TEST(Backward, part_of_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); + auto backward = f::Backward(*fwd, {"a"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_EQ(net->ops_.size(), 1UL); + + auto &grad_mul = *net->ops_[0]; + ASSERT_EQ(grad_mul.type_, "mul_grad"); + ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); + ASSERT_EQ(grad_mul.outputs_.size(), 2UL); + ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()), + f::OperatorBase::EMPTY_VAR_NAME()); + ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "b" + f::OperatorBase::GRAD_VAR_SUFFIX()); + ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out" + f::OperatorBase::GRAD_VAR_SUFFIX()); } \ No newline at end of file -- GitLab