提交 1294b3c5 编写于 作者: Y Yu Yang 提交者: Qiao Longfei

Expose Net to Python (#2967)

* Expose Net to Python

* Expose PlainNet to Python, make python can add_op, complete_add_op
* Provide a low level api to manipulate Net
* Unittest for Net::DebugString
上级 d1dbf2cf
......@@ -39,19 +39,22 @@ void PlainNet::CompleteAddOp(bool calc) {
output_set.insert(opt);
}
}
inputs_.reserve(input_set.size());
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_));
std::sort(inputs_.begin(), inputs_.end());
outputs_.reserve(output_set.size());
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs_));
std::sort(outputs_.begin(), outputs_.end());
std::vector<int> tmp_index;
tmp_index.reserve(temp_output.size());
int idx = 0;
for (auto& opt : output_set) {
if (Contains(temp_output, opt)) {
tmp_index.push_back(idx);
int output_len = static_cast<int>(outputs_.size());
for (int i = 0; i < output_len; ++i) {
if (Contains(temp_output, outputs_[i])) {
tmp_index.push_back(i);
}
outputs_.push_back(opt);
++idx;
}
attrs_["temporary_index"] = tmp_index;
......@@ -59,9 +62,12 @@ void PlainNet::CompleteAddOp(bool calc) {
std::string PlainNet::DebugString() const {
std::ostringstream os;
os << this->type_ << ":" << std::endl;
os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) {
os << "\t" << op->DebugString() << std::endl;
std::istringstream is(op->DebugString());
for (std::string line; std::getline(is, line);) {
os << " " << line << std::endl;
}
}
return os.str();
}
......
......@@ -13,16 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include <paddle/framework/scope.h>
#include <paddle/pybind/tensor_bind.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <fstream>
#include <vector>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace pd = paddle::framework;
......@@ -35,8 +37,19 @@ USE_OP(sigmoid);
USE_OP(softmax);
USE_OP(rowwise_add);
template <typename ClassType>
void ExposeOperator(ClassType& m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("outputs",
[](const typename ClassType::type& op) -> std::vector<std::string> {
return op.outputs_;
})
.def("__str__", &ClassType::type::DebugString);
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of Paddle Paddle");
py::module m("core", "C++ core of PaddlePaddle");
py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer([](pd::Tensor& self) -> py::buffer_info {
......@@ -113,10 +126,9 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CPUDeviceContext();
});
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
.def("__str__", &pd::OperatorBase::DebugString)
.def_static("create",
[](py::bytes protobin) {
py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator");
operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr {
pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
......@@ -124,10 +136,27 @@ All parameter, weight, gradient are variables in Paddle.
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return pd::OpRegistry::CreateOp(desc);
});
ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
plain_net
.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &pd::PlainNet::AddOp)
.def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net));
})
.def("infer_shape", &pd::OperatorBase::InferShape)
.def("run", &pd::OperatorBase::Run)
.def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
.def("complete_add_op", &pd::PlainNet::CompleteAddOp)
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
ExposeOperator(plain_net);
return m.ptr();
}
......@@ -3,6 +3,7 @@ add_python_test(test_framework
test_scope.py
test_default_scope_funcs.py
test_op_creation_methods.py
test_plain_net.py
test_tensor.py
test_fc_op.py
test_add_two_op.py
......
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.PlainNet.create()
op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
net.add_op(op1)
net2 = core.PlainNet.create()
net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
net.complete_add_op(True)
expected = '''
Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
Op(add_two), inputs:(X, Y), outputs:(Out).
Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0).
Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
'''
self.assertEqual(expected, "\n" + str(net))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册