提交 ae9378f6 编写于 作者: Y Yu Yang

Refine PyBind

上级 a1a01899
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <Python.h> #include <Python.h>
#include <cmake-build-release/third_party/pybind/src/extern_pybind/include/pybind11/common.h>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -57,7 +58,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -57,7 +58,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
framework::Tensor dst_tensor; framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.place())) { bool is_gpu = paddle::platform::is_gpu_place(tensor.place());
if (is_gpu) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>()); auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>());
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>( auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
...@@ -74,16 +76,44 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -74,16 +76,44 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
dst_tensor = tensor; dst_tensor = tensor;
} }
if (std::type_index(typeid(CUR_TYPE)) == std::string dtype = std::type_index(typeid(CUR_TYPE)) ==
std::type_index(typeid(platform::float16))) { std::type_index(typeid(platform::float16))
return pybind11::buffer_info( ? std::string("e") // np.dtype('e') == np.float16
dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE), : pybind11::format_descriptor<CUR_TYPE>::format();
"e", /* np.dtype('e') == np.float16 */
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); if (is_gpu) {
// manually construct a py_buffer if is_gpu since gpu data is copied
// into CPU.
// TODO(yy): Is these following code memleak?
Py_buffer *py_buffer =
reinterpret_cast<Py_buffer *>(malloc(sizeof(Py_buffer)));
py_buffer->format = strdup(dtype.c_str());
py_buffer->itemsize = sizeof(CUR_TYPE);
py_buffer->ndim = framework::arity(dst_tensor.dims());
py_buffer->len = tensor.numel();
py_buffer->strides = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * strides.size()));
for (size_t i = 0; i < strides.size(); ++i) {
py_buffer->strides[i] = strides[i];
}
py_buffer->shape = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * tensor.dims().size()));
for (size_t i = 0; i < tensor.dims().size(); ++i) {
py_buffer->shape[i] = tensor.dims()[i];
}
py_buffer->readonly = false;
py_buffer->suboffsets = nullptr;
py_buffer->obj = nullptr;
py_buffer->buf =
malloc(static_cast<size_t>(py_buffer->len * py_buffer->itemsize));
memcpy(py_buffer->buf, dst_tensor.data<CUR_TYPE>(),
static_cast<size_t>(py_buffer->len * py_buffer->itemsize));
return pybind11::buffer_info(py_buffer, true);
} else { } else {
return pybind11::buffer_info( return pybind11::buffer_info(
dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE), dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE), dtype,
pybind11::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
} }
} else { } else {
......
...@@ -289,9 +289,9 @@ class TestFP16CUDNNWithGroup(TestWithGroup): ...@@ -289,9 +289,9 @@ class TestFP16CUDNNWithGroup(TestWithGroup):
self.check_output_with_place(place, atol=2e-2) self.check_output_with_place(place, atol=2e-2)
# class TestCUDNNWith1x1(TestWith1x1): class TestCUDNNWith1x1(TestWith1x1):
# def init_kernel_type(self): def init_kernel_type(self):
# self.use_cudnn = True self.use_cudnn = True
class TestFP16CUDNNWith1x1(TestWith1x1): class TestFP16CUDNNWith1x1(TestWith1x1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册