提交 21baa94b 编写于 作者: Y Yu Yang

Merge branch 'feature/expose_net_op' into feature/middle_level_net_api

...@@ -61,7 +61,10 @@ std::string PlainNet::DebugString() const { ...@@ -61,7 +61,10 @@ std::string PlainNet::DebugString() const {
std::ostringstream os; std::ostringstream os;
os << this->type_ << ":" << std::endl; os << this->type_ << ":" << std::endl;
for (auto& op : ops_) { for (auto& op : ops_) {
os << "\t" << op->DebugString() << std::endl; std::istringstream is(op->DebugString());
for (std::string line; std::getline(is, line);) {
os << " " << line << std::endl;
}
} }
return os.str(); return os.str();
} }
......
...@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <Python.h> #include <Python.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include <paddle/framework/scope.h>
#include <paddle/pybind/tensor_bind.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <fstream> #include <fstream>
#include <vector> #include <vector>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11; namespace py = pybind11;
namespace pd = paddle::framework; namespace pd = paddle::framework;
...@@ -29,6 +30,17 @@ namespace pd = paddle::framework; ...@@ -29,6 +30,17 @@ namespace pd = paddle::framework;
USE_OP(add_two); USE_OP(add_two);
USE_OP_WITHOUT_KERNEL(fc); USE_OP_WITHOUT_KERNEL(fc);
template <typename ClassType>
void ExposeOperator(ClassType& m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("outputs",
[](const typename ClassType::type& op) -> std::vector<std::string> {
return op.outputs_;
})
.def("__str__", &ClassType::type::DebugString);
}
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of Paddle Paddle"); py::module m("core", "C++ core of Paddle Paddle");
...@@ -107,10 +119,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -107,10 +119,9 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CPUDeviceContext(); return new paddle::platform::CPUDeviceContext();
}); });
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator") py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator");
.def("__str__", &pd::OperatorBase::DebugString)
.def_static("create", operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr {
[](py::bytes protobin) {
pd::OpDesc desc; pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc"); "Cannot parse user input to OpDesc");
...@@ -118,10 +129,26 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -118,10 +129,26 @@ All parameter, weight, gradient are variables in Paddle.
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
return pd::OpRegistry::CreateOp(desc); return pd::OpRegistry::CreateOp(desc);
});
ExposeOperator(operator_base);
using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<pd::PlainNet> {
auto retv = std::make_shared<pd::PlainNet>();
retv->type_ = "naive_net";
return retv;
})
.def("add_op", &pd::PlainNet::AddOp)
.def("add_op",
[](PlainNetPtr& self, const PlainNetPtr& net) -> void {
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
}) })
.def("infer_shape", &pd::OperatorBase::InferShape) .def("complete_add_op", &pd::PlainNet::CompleteAddOp)
.def("run", &pd::OperatorBase::Run) .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
.def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); ExposeOperator(net);
return m.ptr(); return m.ptr();
} }
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
import unittest
class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.Net.create()
op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
net.add_op(op1)
net2 = core.Net.create()
net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
net.complete_add_op(True)
expected = '''naive_net:
Op(add_two), inputs:(X, Y), outputs:(Out).
naive_net:
fc:
Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
'''
self.assertEqual(expected, str(net))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册