提交 1dc53a28 编写于 作者: Yu Yang

Use a friend declaration to avoid exposing tensor's `type()`/`place()` publicly

上级 a89c7ffa
......@@ -24,6 +24,12 @@ limitations under the License. */
#include "paddle/platform/place.h"
namespace paddle {
namespace pybind {
namespace details { // forward declare
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
} // namespace details
} // namespace pybind
namespace framework {
class Tensor {
......@@ -128,10 +134,6 @@ class Tensor {
DDim dims() const { return dims_; }
platform::Place place() const { return holder_->place(); }
std::type_index type() const { return holder_->type(); }
private:
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
......@@ -186,7 +188,9 @@ class Tensor {
DDim dims_;
size_t numel_; // cache of `product(dims_)`
size_t offset_; // marks the begin of tensor data area.
}; // namespace framework
template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl;
}; // namespace framework
} // namespace framework
} // namespace paddle
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <Python.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/scope.h>
#include <paddle/pybind/tensor.h>
#include <paddle/pybind/tensor_bind.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
......@@ -32,8 +32,6 @@ PYBIND11_PLUGIN(core) {
py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer([](pd::Tensor& self) -> py::buffer_info {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(self.place()),
"Only CPU tensor can cast to numpy array");
return paddle::pybind::CastToPyBuffer(self);
})
.def("get_dims",
......
......@@ -40,7 +40,10 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) {
if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
"Only CPU tensor can cast to numpy array");
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside;
std::vector<size_t> strides;
......@@ -54,12 +57,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod *= dims_outside[i - 1];
}
return py::buffer_info(tensor.mutable_data<CUR_TYPE>(tensor.place()),
sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()),
dims_outside,
strides);
return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()),
dims_outside,
strides);
} else {
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册