Commit 3fc68f6f, authored by Yu Yang

Move pybind.cc/tensor_bind.h to paddle::framework

Fix #3171
Parent: bfaea910
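The mechanical core of the change: once the binding code lives inside paddle::framework itself, framework types are visible unqualified, so the old "namespace pd = paddle::framework;" alias can be deleted and every pd:: prefix dropped. A minimal sketch of that effect, not taken from the commit (the stand-in Tensor and Bind() are illustrative only):

namespace paddle {
namespace framework {

struct Tensor {};  // illustrative stand-in for the real framework::Tensor

void Bind() {
  // Inside the enclosing namespace, Tensor resolves without any
  // pd:: or paddle::framework:: qualification.
  Tensor t;
  (void)t;
}

}  // namespace framework
}  // namespace paddle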
paddle/CMakeLists.txt
@@ -15,7 +15,6 @@ if(Boost_FOUND)
   add_subdirectory(platform)
   add_subdirectory(framework)
   add_subdirectory(operators)
-  add_subdirectory(pybind)
 endif()
 
 if(WITH_C_API)
......
paddle/framework/CMakeLists.txt
@@ -36,3 +36,12 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net)
 cc_library(backward SRCS backward.cc DEPS net)
 cc_test(backward_test SRCS backward_test.cc DEPS backward)
 
+cc_library(paddle_pybind SHARED
+    SRCS pybind.cc
+    DEPS pybind python
+         fc_op
+         sgd_op
+         add_op
+         mean_op
+         cross_entropy_op
+         recurrent_network_op)
paddle/framework/pybind.cc (moved from paddle/pybind/pybind.cc)
@@ -20,13 +20,12 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
-#include "paddle/pybind/tensor_bind.h"
+#include "paddle/framework/tensor_bind.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
 
 namespace py = pybind11;
-namespace pd = paddle::framework;
 
 USE_OP(add_two);
 USE_OP(onehot_cross_entropy);
@@ -38,13 +37,14 @@ USE_OP(sigmoid);
 USE_OP(softmax);
 USE_OP(rowwise_add);
 USE_OP_WITHOUT_KERNEL(recurrent_op);
+
+namespace paddle {
+namespace framework {
+
 template <typename ClassType>
-void ExposeOperator(ClassType& m) {
+void ExposeOperator(ClassType &m) {
   m.def("infer_shape", &ClassType::type::InferShape)
       .def("run", &ClassType::type::Run)
       .def("outputs",
-           [](const typename ClassType::type& op) -> std::vector<std::string> {
+           [](const typename ClassType::type &op) -> std::vector<std::string> {
             return op.outputs_;
           })
       .def("__str__", &ClassType::type::DebugString);
@@ -58,68 +58,58 @@ static size_t UniqueIntegerGenerator() {
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");
 
-  py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
-      .def_buffer([](pd::Tensor& self) -> py::buffer_info {
-        return paddle::pybind::CastToPyBuffer(self);
-      })
+  py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
+      .def_buffer(
+          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
       .def("get_dims",
-           [](const pd::Tensor& self) { return pd::vectorize(self.dims()); })
+           [](const Tensor &self) { return vectorize(self.dims()); })
       .def("set_dims",
-           [](pd::Tensor& self, const std::vector<int>& dim) {
-             self.Resize(pd::make_ddim(dim));
+           [](Tensor &self, const std::vector<int> &dim) {
+             self.Resize(make_ddim(dim));
           })
       .def("alloc_float",
-           [](pd::Tensor& self) {
+           [](Tensor &self) {
             self.mutable_data<float>(paddle::platform::CPUPlace());
           })
       .def("alloc_int",
-           [](pd::Tensor& self) {
+           [](Tensor &self) {
             self.mutable_data<int>(paddle::platform::CPUPlace());
           })
-      .def("set", paddle::pybind::PyTensorSetFromArray<float>)
-      .def("set", paddle::pybind::PyTensorSetFromArray<int>)
-      .def("shape",
-           [](pd::Tensor& self) { return pd::vectorize(self.dims()); });
+      .def("set", PyTensorSetFromArray<float>)
+      .def("set", PyTensorSetFromArray<int>)
+      .def("shape", [](Tensor &self) { return vectorize(self.dims()); });
 
-  py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class.
+  py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
 
 All parameter, weight, gradient are variables in Paddle.
 )DOC")
-      .def("is_int", [](const pd::Variable& var) { return var.IsType<int>(); })
+      .def("is_int", [](const Variable &var) { return var.IsType<int>(); })
       .def("set_int",
-           [](pd::Variable& var, int val) -> void {
-             *var.GetMutable<int>() = val;
-           })
-      .def("get_int",
-           [](const pd::Variable& var) -> int { return var.Get<int>(); })
+           [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
+      .def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
       .def("get_tensor",
-           [](pd::Variable& self) -> pd::Tensor* {
-             return self.GetMutable<pd::Tensor>();
-           },
+           [](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
           py::return_value_policy::reference)
       .def("get_net",
-           [](pd::Variable& self) -> pd::NetOp* {
-             return self.GetMutable<pd::NetOp>();
-           },
+           [](Variable &self) -> NetOp * { return self.GetMutable<NetOp>(); },
           py::return_value_policy::reference);
 
-  py::class_<pd::Scope>(m, "Scope", "")
+  py::class_<Scope>(m, "Scope", "")
       .def("new_var",
-           [](pd::Scope& self, const std::string& name) -> pd::Variable* {
+           [](Scope &self, const std::string &name) -> Variable * {
             return self.NewVar(name);
           },
           py::return_value_policy::reference)
-      .def("find_var", &pd::Scope::FindVar, py::return_value_policy::reference)
+      .def("find_var", &Scope::FindVar, py::return_value_policy::reference)
       .def(py::init<>())
-      .def("new_scope",
-           [](pd::Scope& self) -> pd::Scope* { return &self.NewScope(); },
+      .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
           py::return_value_policy::reference)
-      .def("drop_kids", &pd::Scope::DropKids);
+      .def("drop_kids", &Scope::DropKids);
 
   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
   m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
-    auto& protos = pd::OpRegistry::protos();
+    auto &protos = OpRegistry::protos();
     std::vector<py::bytes> ret_values;
     for (auto it = protos.begin(); it != protos.end(); ++it) {
       PADDLE_ENFORCE(it->second.IsInitialized(),
@@ -134,47 +124,49 @@ All parameter, weight, gradient are variables in Paddle.
   m.def_submodule(
        "var_names",
        "The module will return special predefined variable name in Paddle")
-      .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
-      .def("temp", pd::OperatorBase::TMP_VAR_NAME);
+      .def("empty", OperatorBase::EMPTY_VAR_NAME)
+      .def("temp", OperatorBase::TMP_VAR_NAME);
 
   py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
-      .def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
+      .def_static("cpu_context", []() -> paddle::platform::DeviceContext * {
         return new paddle::platform::CPUDeviceContext();
       });
 
-  py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
+  py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
       m, "Operator");
 
   operator_base.def_static("create", [](py::bytes protobin) {
-    pd::OpDesc desc;
+    OpDesc desc;
     PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                    "Cannot parse user input to OpDesc");
     PADDLE_ENFORCE(desc.IsInitialized(),
                    "User OpDesc is not initialized, reason %s",
                    desc.InitializationErrorString());
-    return pd::OpRegistry::CreateOp(desc);
+    return OpRegistry::CreateOp(desc);
   });
   ExposeOperator(operator_base);
 
-  py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
+  py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net");
 
   net.def_static("create",
-                 []() -> std::shared_ptr<pd::NetOp> {
-                   auto retv = std::make_shared<pd::NetOp>();
+                 []() -> std::shared_ptr<NetOp> {
+                   auto retv = std::make_shared<NetOp>();
                    retv->type_ = "plain_net";
                    return retv;
                  })
-      .def("add_op", &pd::NetOp::AddOp)
+      .def("add_op", &NetOp::AddOp)
       .def("add_op",
-           [](pd::NetOp& self, const std::shared_ptr<pd::NetOp>& net) -> void {
-             self.AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
+           [](NetOp &self, const std::shared_ptr<NetOp> &net) -> void {
+             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
           })
-      .def("complete_add_op", &pd::NetOp::CompleteAddOp)
+      .def("complete_add_op", &NetOp::CompleteAddOp)
       .def("complete_add_op",
-           [](std::shared_ptr<pd::NetOp>& self) { self->CompleteAddOp(); });
+           [](std::shared_ptr<NetOp> &self) { self->CompleteAddOp(); });
   ExposeOperator(net);
 
   m.def("unique_integer", UniqueIntegerGenerator);
 
   return m.ptr();
 }
+
+}  // namespace framework
+}  // namespace paddle
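A note on the ExposeOperator template retained above: py::class_ publishes the wrapped C++ type as its member typedef "type", so one function template can attach identical methods to both the Operator and Net bindings. A self-contained sketch of the pattern under that assumption (FakeOp, AddCommonMethods, and the demo module are hypothetical, and the modern PYBIND11_MODULE macro stands in for the PYBIND11_PLUGIN used in this commit):

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct FakeOp {  // hypothetical stand-in for OperatorBase
  std::string DebugString() const { return "fake_op"; }
};

template <typename ClassType>
void AddCommonMethods(ClassType &cls) {
  // ClassType::type is the wrapped C++ type, so the member pointer
  // resolves against whichever class the template is instantiated with.
  cls.def("__str__", &ClassType::type::DebugString);
}

PYBIND11_MODULE(demo, m) {
  py::class_<FakeOp> fake_op(m, "FakeOp");
  fake_op.def(py::init<>());
  AddCommonMethods(fake_op);
}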
paddle/framework/tensor.h
@@ -26,19 +26,17 @@ limitations under the License. */
 #include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
-namespace pybind {
-namespace details {  // forward declare
-template <bool less, size_t i, typename... args>
-struct CastToPyBufferImpl;
-}  // namespace details
-}  // namespace pybind
-
 namespace framework {
+namespace details {
+template <bool less, size_t i, typename... args>
+struct CastToPyBufferImpl;
+}
 
 class Tensor {
  public:
   template <bool less, size_t i, typename... args>
-  friend struct paddle::pybind::details::CastToPyBufferImpl;
+  friend struct details::CastToPyBufferImpl;
 
   template <typename T, size_t D, int MajorType, typename IndexType>
   friend struct EigenTensor;
......
paddle/framework/tensor_bind.h (moved from paddle/pybind/tensor_bind.h)
@@ -21,7 +21,7 @@ namespace py = pybind11;
 
 namespace paddle {
-namespace pybind {
+namespace framework {
 
 namespace details {
@@ -59,11 +59,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
       return py::buffer_info(
           tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
-          sizeof(CUR_TYPE),
-          py::format_descriptor<CUR_TYPE>::format(),
-          (size_t)framework::arity(tensor.dims()),
-          dims_outside,
-          strides);
+          sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
+          (size_t)framework::arity(tensor.dims()), dims_outside, strides);
     } else {
       constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
       return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
......
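CastToPyBufferImpl ends by constructing a py::buffer_info, which is everything pybind11 needs to satisfy Python's buffer protocol: data pointer, element size, format string, rank, shape, and byte strides; numpy can then wrap the tensor memory without copying. A free-standing sketch of the same construction (DenseFloatArray and ToBuffer are hypothetical, not Paddle APIs):

#include <pybind11/pybind11.h>

#include <vector>

namespace py = pybind11;

struct DenseFloatArray {  // hypothetical contiguous row-major array
  std::vector<float> data;
  std::vector<py::ssize_t> shape;
};

py::buffer_info ToBuffer(DenseFloatArray &a) {
  // Compute row-major strides in bytes, innermost dimension last.
  std::vector<py::ssize_t> strides(a.shape.size());
  py::ssize_t stride = sizeof(float);
  for (int i = static_cast<int>(a.shape.size()) - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= a.shape[i];
  }
  return py::buffer_info(a.data.data(), sizeof(float),
                         py::format_descriptor<float>::format(),
                         static_cast<py::ssize_t>(a.shape.size()), a.shape,
                         strides);
}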