Commit 1294b3c5 authored by Yu Yang, committed by Qiao Longfei

Expose Net to Python (#2967)

* Expose Net to Python

* Expose PlainNet to Python, so Python code can call add_op and complete_add_op
* Provide a low-level API to manipulate Net
* Unittest for Net::DebugString
Parent d1dbf2cf
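The new bindings are small: core.PlainNet.create() builds an empty plain_net, add_op appends either a plain operator or another net, complete_add_op(True) infers the net's inputs_/outputs_ and its temporary-output indices, and the methods bound through ExposeOperator (infer_shape, run, outputs, __str__) are shared by operators and nets. A minimal usage sketch, assuming the op_creations helpers and argument names used in the test file at the end of this diff:

import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations

# Build a net and append a plain operator (argument names follow test_plain_net.py).
net = core.PlainNet.create()
net.add_op(op_creations.add_two(X="X", Y="Y", Out="Out"))

# A completed net is itself an operator, so it can be nested inside another net.
inner = core.PlainNet.create()
inner.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
inner.complete_add_op(True)   # infer inner's inputs/outputs before nesting it
net.add_op(inner)
net.complete_add_op(True)

print(str(net))               # indented DebugString of the whole nested net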
@@ -39,19 +39,22 @@ void PlainNet::CompleteAddOp(bool calc) {
       output_set.insert(opt);
     }
   }
   inputs_.reserve(input_set.size());
   std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_));
+  std::sort(inputs_.begin(), inputs_.end());
   outputs_.reserve(output_set.size());
+  std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs_));
+  std::sort(outputs_.begin(), outputs_.end());
   std::vector<int> tmp_index;
   tmp_index.reserve(temp_output.size());
-  int idx = 0;
-  for (auto& opt : output_set) {
-    if (Contains(temp_output, opt)) {
-      tmp_index.push_back(idx);
+  int output_len = static_cast<int>(outputs_.size());
+  for (int i = 0; i < output_len; ++i) {
+    if (Contains(temp_output, outputs_[i])) {
+      tmp_index.push_back(i);
     }
-    outputs_.push_back(opt);
-    ++idx;
   }
   attrs_["temporary_index"] = tmp_index;
@@ -59,9 +62,12 @@ void PlainNet::CompleteAddOp(bool calc) {
 std::string PlainNet::DebugString() const {
   std::ostringstream os;
-  os << this->type_ << ":" << std::endl;
+  os << OperatorBase::DebugString() << std::endl;
   for (auto& op : ops_) {
-    os << "\t" << op->DebugString() << std::endl;
+    std::istringstream is(op->DebugString());
+    for (std::string line; std::getline(is, line);) {
+      os << "    " << line << std::endl;
+    }
   }
   return os.str();
 }
@@ -13,16 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <Python.h>
-#include <paddle/framework/op_registry.h>
-#include <paddle/framework/operator.h>
-#include <paddle/framework/scope.h>
-#include <paddle/pybind/tensor_bind.h>
-#include <pybind11/numpy.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
 #include <fstream>
 #include <vector>
+#include "paddle/framework/net.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/scope.h"
+#include "paddle/pybind/tensor_bind.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
 namespace py = pybind11;
 namespace pd = paddle::framework;
@@ -35,8 +37,19 @@ USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
+template <typename ClassType>
+void ExposeOperator(ClassType& m) {
+  m.def("infer_shape", &ClassType::type::InferShape)
+      .def("run", &ClassType::type::Run)
+      .def("outputs",
+           [](const typename ClassType::type& op) -> std::vector<std::string> {
+             return op.outputs_;
+           })
+      .def("__str__", &ClassType::type::DebugString);
+}
 PYBIND11_PLUGIN(core) {
-  py::module m("core", "C++ core of Paddle Paddle");
+  py::module m("core", "C++ core of PaddlePaddle");
   py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
       .def_buffer([](pd::Tensor& self) -> py::buffer_info {
@@ -113,10 +126,9 @@ All parameter, weight, gradient are variables in Paddle.
         return new paddle::platform::CPUDeviceContext();
       });
-  py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
-      .def("__str__", &pd::OperatorBase::DebugString)
-      .def_static("create",
-                  [](py::bytes protobin) {
+  py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator");
+  operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr {
     pd::OpDesc desc;
     PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                    "Cannot parse user input to OpDesc");
@@ -124,10 +136,27 @@ All parameter, weight, gradient are variables in Paddle.
                    "User OpDesc is not initialized, reason %s",
                    desc.InitializationErrorString());
     return pd::OpRegistry::CreateOp(desc);
-                  })
-      .def("infer_shape", &pd::OperatorBase::InferShape)
-      .def("run", &pd::OperatorBase::Run)
-      .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
+  });
+  ExposeOperator(operator_base);
+  using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
+  py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
+  plain_net
+      .def_static("create",
+                  []() -> std::shared_ptr<pd::PlainNet> {
+                    auto retv = std::make_shared<pd::PlainNet>();
+                    retv->type_ = "plain_net";
+                    return retv;
+                  })
+      .def("add_op", &pd::PlainNet::AddOp)
+      .def("add_op",
+           [](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void {
+             self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net));
+           })
+      .def("complete_add_op", &pd::PlainNet::CompleteAddOp)
+      .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
+  ExposeOperator(plain_net);
   return m.ptr();
 }
@@ -3,6 +3,7 @@ add_python_test(test_framework
     test_scope.py
     test_default_scope_funcs.py
     test_op_creation_methods.py
+    test_plain_net.py
     test_tensor.py
     test_fc_op.py
    test_add_two_op.py
test_plain_net.py (new file):

import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
import unittest


class TestNet(unittest.TestCase):
    def test_net_all(self):
        net = core.PlainNet.create()
        op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
        net.add_op(op1)

        net2 = core.PlainNet.create()
        net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
        net2.complete_add_op(True)
        net.add_op(net2)
        net.complete_add_op(True)

        expected = '''
Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
    Op(add_two), inputs:(X, Y), outputs:(Out).
    Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
        Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0).
            Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
            Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
'''
        self.assertEqual(expected, "\n" + str(net))


if __name__ == '__main__':
    unittest.main()