From 1dc53a289fe724cd3772618de374aacbf72a87f6 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 18 Jul 2017 15:23:13 +0800 Subject: [PATCH] Use friend not to expose tensor's `type/place` --- paddle/framework/tensor.h | 14 +++++++++----- paddle/pybind/pybind.cc | 4 +--- paddle/pybind/{tensor.h => tensor_bind.h} | 18 +++++++++++------- 3 files changed, 21 insertions(+), 15 deletions(-) rename paddle/pybind/{tensor.h => tensor_bind.h} (84%) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 891cf73641..c495687dc4 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -24,6 +24,12 @@ limitations under the License. */ #include "paddle/platform/place.h" namespace paddle { +namespace pybind { +namespace details { // forward declare +template +struct CastToPyBufferImpl; +} // namespace details +} // namespace pybind namespace framework { class Tensor { @@ -128,10 +134,6 @@ class Tensor { DDim dims() const { return dims_; } - platform::Place place() const { return holder_->place(); } - - std::type_index type() const { return holder_->type(); } - private: // Placeholder hides type T, so it doesn't appear as a template // parameter of Variable. @@ -186,7 +188,9 @@ class Tensor { DDim dims_; size_t numel_; // cache of `product(dims_)` size_t offset_; // marks the begin of tensor data area. -}; // namespace framework + template + friend struct paddle::pybind::details::CastToPyBufferImpl; +}; // namespace framework } // namespace framework } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index e3dc3e718c..0eef36f8ec 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include -#include +#include #include #include #include @@ -32,8 +32,6 @@ PYBIND11_PLUGIN(core) { py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer([](pd::Tensor& self) -> py::buffer_info { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(self.place()), - "Only CPU tensor can cast to numpy array"); return paddle::pybind::CastToPyBuffer(self); }) .def("get_dims", diff --git a/paddle/pybind/tensor.h b/paddle/pybind/tensor_bind.h similarity index 84% rename from paddle/pybind/tensor.h rename to paddle/pybind/tensor_bind.h index ef07144ad4..b96516643a 100644 --- a/paddle/pybind/tensor.h +++ b/paddle/pybind/tensor_bind.h @@ -40,7 +40,10 @@ template struct CastToPyBufferImpl { using CUR_TYPE = typename std::tuple_element>::type; py::buffer_info operator()(framework::Tensor &tensor) { - if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()), + "Only CPU tensor can cast to numpy array"); + + if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector dims_outside; std::vector strides; @@ -54,12 +57,13 @@ struct CastToPyBufferImpl { prod *= dims_outside[i - 1]; } - return py::buffer_info(tensor.mutable_data(tensor.place()), - sizeof(CUR_TYPE), - py::format_descriptor::format(), - (size_t)framework::arity(tensor.dims()), - dims_outside, - strides); + return py::buffer_info( + tensor.mutable_data(tensor.holder_->place()), + sizeof(CUR_TYPE), + py::format_descriptor::format(), + (size_t)framework::arity(tensor.dims()), + dims_outside, + strides); } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); -- GitLab