diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index 7ffe4c28103f9d6a9f179422d1beb86106ef786e..f1ab8056450c96f0a1b671e1efa46c4c68f9ea15 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -27,6 +27,8 @@ extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); +// TODO(jiayi): Add target as parameter and generate backward op +// according to target. void AppendBackward(ProgramDescBind& program_desc, const std::unordered_set& no_grad_vars); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 116c99bd2c1ca59b093392f9e6cc481c089309bc..0e73939424633ea61c85db25000e0f21b8a4a402 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" #include #include +#include "paddle/framework/backward.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/op_desc.h" #include "paddle/framework/program_desc.h" @@ -116,6 +117,11 @@ void BindProgramDesc(py::module &m) { py::return_value_policy::reference) .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) + .def("append_backward", + [](ProgramDescBind &program_desc, + const std::unordered_set &no_grad_vars) { + AppendBackward(program_desc, no_grad_vars); + }) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("num_blocks", &ProgramDescBind::Size); } diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py index b82d1760d65a24401aaa336bc41f75ed60af8ae9..83e184494ad235f6493a7ea8e25886b1e35004ee 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/framework/tests/test_program.py @@ -1,4 +1,6 @@ import unittest + +import paddle.v2.framework.core as core from paddle.v2.framework.graph import g_program @@ -31,6 +33,34 @@ class TestProgram(unittest.TestCase): self.assertEqual(1, b.idx) self.assertEqual(0, b.parent_idx) + def test_append_backward(self): + prog = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog) + block = prog.block(0) + self.assertIsNotNone(block) + + mul_op_desc = block.append_op() + mul_op_desc.set_type("mul") + mul_op_desc.set_input("X", ["x1"]) + mul_op_desc.set_input("Y", ["y1"]) + mul_op_desc.set_output("Out", ["out1"]) + + sum_op_desc = block.append_op() + sum_op_desc.set_type("elementwise_add") + sum_op_desc.set_input("X", ["out1"]) + sum_op_desc.set_input("Y", ["b1"]) + sum_op_desc.set_output("Out", ["out2"]) + + expect_ops = [ + "mul", "elementwise_add", "elementwise_add_grad", "mul_grad" + ] + actual_ops = [] + prog.append_backward(set()) + for op in block.all_ops(): + actual_ops.append(op.type()) + print(actual_ops) + self.assertEqual(actual_ops, expect_ops) + if __name__ == '__main__': unittest.main()