From e8cad5a1d00967fb83ff9632672e0650a5f67af8 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 10 Oct 2017 22:46:16 -0700 Subject: [PATCH] add more unit test for test_append_backward --- paddle/pybind/protobuf.cc | 2 +- .../paddle/v2/framework/tests/test_program.py | 27 ++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 807694fc0..0e7393942 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -117,7 +117,7 @@ void BindProgramDesc(py::module &m) { py::return_value_policy::reference) .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) - .def("backward", + .def("append_backward", [](ProgramDescBind &program_desc, const std::unordered_set &no_grad_vars) { AppendBackward(program_desc, no_grad_vars); diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index 6eae378c9..83e184494 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -33,20 +33,33 @@ class TestProgram(unittest.TestCase): self.assertEqual(1, b.idx) self.assertEqual(0, b.parent_idx) - def test_backward(self): + def test_append_backward(self): prog = core.ProgramDesc.__create_program_desc__() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) + mul_op_desc = block.append_op() + mul_op_desc.set_type("mul") + mul_op_desc.set_input("X", ["x1"]) + mul_op_desc.set_input("Y", ["y1"]) + mul_op_desc.set_output("Out", ["out1"]) + sum_op_desc = block.append_op() - sum_op_desc.set_type("sum") - sum_op_desc.set_input("X", ["x1", "x2"]) - sum_op_desc.set_output("Out", ["out"]) + sum_op_desc.set_type("elementwise_add") + sum_op_desc.set_input("X", ["out1"]) + sum_op_desc.set_input("Y", ["b1"]) + sum_op_desc.set_output("Out", ["out2"]) - self.assertEqual(len(block.all_ops()), 1) - prog.backward(set()) - self.assertEqual(len(block.all_ops()), 3) + expect_ops = [ + "mul", "elementwise_add", "elementwise_add_grad", "mul_grad" + ] + actual_ops = [] + prog.append_backward(set()) + for op in block.all_ops(): + actual_ops.append(op.type()) + print(actual_ops) + self.assertEqual(actual_ops, expect_ops) if __name__ == '__main__': -- GitLab