提交 e9a0c4ef 编写于 作者: Q qiaolongfei

expose AppendBackward of ProgramDesc to python

上级 d1479d93
...@@ -27,6 +27,8 @@ extern std::unique_ptr<OperatorBase> Backward( ...@@ -27,6 +27,8 @@ extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
// TODO(someone): Add target as parameter and generate backward op
// according to target.
void AppendBackward(ProgramDescBind& program_desc, void AppendBackward(ProgramDescBind& program_desc,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/pybind/protobuf.h" #include "paddle/pybind/protobuf.h"
#include <deque> #include <deque>
#include <iostream> #include <iostream>
#include "paddle/framework/backward.h"
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/program_desc.h" #include "paddle/framework/program_desc.h"
...@@ -116,6 +117,11 @@ void BindProgramDesc(py::module &m) { ...@@ -116,6 +117,11 @@ void BindProgramDesc(py::module &m) {
py::return_value_policy::reference) py::return_value_policy::reference)
.def("append_block", &ProgramDescBind::AppendBlock, .def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("backward",
[](ProgramDescBind &program_desc,
const std::unordered_set<std::string> &no_grad_vars) {
AppendBackward(program_desc, no_grad_vars);
})
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size); .def("num_blocks", &ProgramDescBind::Size);
} }
......
import unittest import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.graph import g_program from paddle.v2.framework.graph import g_program
...@@ -31,6 +33,21 @@ class TestProgram(unittest.TestCase): ...@@ -31,6 +33,21 @@ class TestProgram(unittest.TestCase):
self.assertEqual(1, b.idx) self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx) self.assertEqual(0, b.parent_idx)
def test_backward(self):
prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
sum_op_desc = block.append_op()
sum_op_desc.set_type("sum")
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
self.assertEqual(len(block.all_ops()), 1)
prog.backward(set())
self.assertEqual(len(block.all_ops()), 3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册