diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index 7ffe4c28103f9d6a9f179422d1beb86106ef786e..24a79d28b3858b000434cd698245c3222a1848c7 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -27,6 +27,8 @@ extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); +// TODO(someone): Add target as parameter and generate backward op +// according to target. void AppendBackward(ProgramDescBind& program_desc, const std::unordered_set& no_grad_vars); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 116c99bd2c1ca59b093392f9e6cc481c089309bc..807694fc08fe613ce7acfaddcee0493a9b877b7e 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" #include #include +#include "paddle/framework/backward.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/op_desc.h" #include "paddle/framework/program_desc.h" @@ -116,6 +117,11 @@ void BindProgramDesc(py::module &m) { py::return_value_policy::reference) .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) + .def("backward", + [](ProgramDescBind &program_desc, + const std::unordered_set &no_grad_vars) { + AppendBackward(program_desc, no_grad_vars); + }) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("num_blocks", &ProgramDescBind::Size); } diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index b82d1760d65a24401aaa336bc41f75ed60af8ae9..6eae378c91eb11b81891cf802e158a153f11d4e2 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -1,4 +1,6 @@ import unittest + +import paddle.v2.framework.core as core from paddle.v2.framework.graph import g_program @@ -31,6 +33,21 @@ class TestProgram(unittest.TestCase): self.assertEqual(1, b.idx) self.assertEqual(0, b.parent_idx) + def test_backward(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + sum_op_desc = block.append_op() + sum_op_desc.set_type("sum") + sum_op_desc.set_input("X", ["x1", "x2"]) + sum_op_desc.set_output("Out", ["out"]) + + self.assertEqual(len(block.all_ops()), 1) + prog.backward(set()) + self.assertEqual(len(block.all_ops()), 3) + if __name__ == '__main__': unittest.main()