提交 1dc53a28 编写于 作者: Y Yu Yang

Use friend not to expose tensor's `type/place`

上级 a89c7ffa
...@@ -24,6 +24,12 @@ limitations under the License. */ ...@@ -24,6 +24,12 @@ limitations under the License. */
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
namespace paddle { namespace paddle {
namespace pybind {
namespace details { // forward declare
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
} // namespace details
} // namespace pybind
namespace framework { namespace framework {
class Tensor { class Tensor {
...@@ -128,10 +134,6 @@ class Tensor { ...@@ -128,10 +134,6 @@ class Tensor {
DDim dims() const { return dims_; } DDim dims() const { return dims_; }
platform::Place place() const { return holder_->place(); }
std::type_index type() const { return holder_->type(); }
private: private:
// Placeholder hides type T, so it doesn't appear as a template // Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable. // parameter of Variable.
...@@ -186,6 +188,8 @@ class Tensor { ...@@ -186,6 +188,8 @@ class Tensor {
DDim dims_; DDim dims_;
size_t numel_; // cache of `product(dims_)` size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area. size_t offset_; // marks the begin of tensor data area.
template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl;
}; // namespace framework }; // namespace framework
} // namespace framework } // namespace framework
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <Python.h> #include <Python.h>
#include <paddle/framework/op_registry.h> #include <paddle/framework/op_registry.h>
#include <paddle/framework/scope.h> #include <paddle/framework/scope.h>
#include <paddle/pybind/tensor.h> #include <paddle/pybind/tensor_bind.h>
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
...@@ -32,8 +32,6 @@ PYBIND11_PLUGIN(core) { ...@@ -32,8 +32,6 @@ PYBIND11_PLUGIN(core) {
py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol()) py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer([](pd::Tensor& self) -> py::buffer_info { .def_buffer([](pd::Tensor& self) -> py::buffer_info {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(self.place()),
"Only CPU tensor can cast to numpy array");
return paddle::pybind::CastToPyBuffer(self); return paddle::pybind::CastToPyBuffer(self);
}) })
.def("get_dims", .def("get_dims",
......
...@@ -40,7 +40,10 @@ template <size_t I, typename... ARGS> ...@@ -40,7 +40,10 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> { struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) { py::buffer_info operator()(framework::Tensor &tensor) {
if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) { PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
"Only CPU tensor can cast to numpy array");
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims()); auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside; std::vector<size_t> dims_outside;
std::vector<size_t> strides; std::vector<size_t> strides;
...@@ -54,7 +57,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -54,7 +57,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
return py::buffer_info(tensor.mutable_data<CUR_TYPE>(tensor.place()), return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
sizeof(CUR_TYPE), sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()), (size_t)framework::arity(tensor.dims()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册