From e9a0c4ef87134d061ba952bb89c0dfe01eedc37e Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 10 Oct 2017 19:57:30 -0700 Subject: [PATCH] expose AppendBackward of ProgramDesc to python --- paddle/framework/backward.h | 2 ++ paddle/pybind/protobuf.cc | 6 ++++++ .../paddle/v2/framework/tests/test_program.py | 17 +++++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index 7ffe4c2810..24a79d28b3 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -27,6 +27,8 @@ extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); +// TODO(someone): Add target as parameter and generate backward op +// according to target. void AppendBackward(ProgramDescBind& program_desc, const std::unordered_set& no_grad_vars); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 116c99bd2c..807694fc08 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" #include #include +#include "paddle/framework/backward.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/op_desc.h" #include "paddle/framework/program_desc.h" @@ -116,6 +117,11 @@ void BindProgramDesc(py::module &m) { py::return_value_policy::reference) .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) + .def("backward", + [](ProgramDescBind &program_desc, + const std::unordered_set &no_grad_vars) { + AppendBackward(program_desc, no_grad_vars); + }) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("num_blocks", &ProgramDescBind::Size); } diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index b82d1760d6..6eae378c91 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -1,4 +1,6 @@ import unittest + +import paddle.v2.framework.core as core from paddle.v2.framework.graph import g_program @@ -31,6 +33,21 @@ class TestProgram(unittest.TestCase): self.assertEqual(1, b.idx) self.assertEqual(0, b.parent_idx) + def test_backward(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + sum_op_desc = block.append_op() + sum_op_desc.set_type("sum") + sum_op_desc.set_input("X", ["x1", "x2"]) + sum_op_desc.set_output("Out", ["out"]) + + self.assertEqual(len(block.all_ops()), 1) + prog.backward(set()) + self.assertEqual(len(block.all_ops()), 3) + if __name__ == '__main__': unittest.main() -- GitLab