From 0380bfb3cf693cf233a3ad5fa4382fc65a2c7a02 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 19 Jul 2017 17:47:06 +0800 Subject: [PATCH] Expose Net to Python * Expose PlainNet to Python, make python can add_op, complete_add_op * Provide a low level api to manipulate Net * Unittest for Net::DebugString --- paddle/framework/net.cc | 5 +- paddle/pybind/pybind.cc | 71 ++++++++++++++------ python/paddle/v2/framework/tests/test_net.py | 28 ++++++++ 3 files changed, 81 insertions(+), 23 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_net.py diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 501536657d7..407a69fda60 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -61,7 +61,10 @@ std::string PlainNet::DebugString() const { std::ostringstream os; os << this->type_ << ":" << std::endl; for (auto& op : ops_) { - os << "\t" << op->DebugString() << std::endl; + std::istringstream is(op->DebugString()); + for (std::string line; std::getline(is, line);) { + os << " " << line << std::endl; + } } return os.str(); } diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 7e84550f770..bd126f0e97a 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include -#include -#include -#include -#include -#include -#include #include #include +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/scope.h" +#include "paddle/pybind/tensor_bind.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" namespace py = pybind11; namespace pd = paddle::framework; @@ -29,6 +30,17 @@ namespace pd = paddle::framework; USE_OP(add_two); USE_OP_WITHOUT_KERNEL(fc); +template +void ExposeOperator(ClassType& m) { + m.def("infer_shape", &ClassType::type::InferShape) + .def("run", &ClassType::type::Run) + .def("outputs", + [](const typename ClassType::type& op) -> std::vector { + return op.outputs_; + }) + .def("__str__", &ClassType::type::DebugString); +} + PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); @@ -107,21 +119,36 @@ All parameter, weight, gradient are variables in Paddle. return new paddle::platform::CPUDeviceContext(); }); - py::class_(m, "Operator") - .def("__str__", &pd::OperatorBase::DebugString) - .def_static("create", - [](py::bytes protobin) { - pd::OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); - }) - .def("infer_shape", &pd::OperatorBase::InferShape) - .def("run", &pd::OperatorBase::Run) - .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); + py::class_ operator_base(m, "Operator"); + + operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr { + pd::OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return pd::OpRegistry::CreateOp(desc); + }); + ExposeOperator(operator_base); + + using PlainNetPtr = std::shared_ptr; + py::class_ net(m, "Net"); + + net.def_static("create", + []() -> std::shared_ptr { + auto retv = std::make_shared(); + retv->type_ = "naive_net"; + return retv; + }) + .def("add_op", &pd::PlainNet::AddOp) + .def("add_op", + [](PlainNetPtr& self, const PlainNetPtr& net) -> void { + self->AddOp(std::static_pointer_cast(net)); + }) + .def("complete_add_op", &pd::PlainNet::CompleteAddOp) + .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); }); + ExposeOperator(net); return m.ptr(); } diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/framework/tests/test_net.py new file mode 100644 index 00000000000..6a97c24990a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_net.py @@ -0,0 +1,28 @@ +import paddle.v2.framework.core as core +from paddle.v2.framework.create_op_creation_methods import op_creations +import unittest + + +class TestNet(unittest.TestCase): + def test_net_all(self): + net = core.Net.create() + op1 = op_creations.add_two(X="X", Y="Y", Out="Out") + net.add_op(op1) + + net2 = core.Net.create() + net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out")) + net2.complete_add_op(True) + net.add_op(net2) + net.complete_add_op(True) + expected = '''naive_net: + Op(add_two), inputs:(X, Y), outputs:(Out). + naive_net: + fc: + Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0). + Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out). +''' + self.assertEqual(expected, str(net)) + + +if __name__ == '__main__': + unittest.main() -- GitLab