提交 ae9378f6 编写于 作者: Y Yu Yang

Refine PyBind

上级 a1a01899
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <Python.h>
#include <cmake-build-release/third_party/pybind/src/extern_pybind/include/pybind11/common.h>
#include <string>
#include <tuple>
#include <vector>
......@@ -57,7 +58,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod *= dims_outside[i - 1];
}
framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.place())) {
bool is_gpu = paddle::platform::is_gpu_place(tensor.place());
if (is_gpu) {
#ifdef PADDLE_WITH_CUDA
auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>());
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
......@@ -74,16 +76,44 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
dst_tensor = tensor;
}
if (std::type_index(typeid(CUR_TYPE)) ==
std::type_index(typeid(platform::float16))) {
return pybind11::buffer_info(
dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
"e", /* np.dtype('e') == np.float16 */
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
std::string dtype = std::type_index(typeid(CUR_TYPE)) ==
std::type_index(typeid(platform::float16))
? std::string("e") // np.dtype('e') == np.float16
: pybind11::format_descriptor<CUR_TYPE>::format();
if (is_gpu) {
// manually construct a py_buffer if is_gpu since gpu data is copied
// into CPU.
// TODO(yy): Is these following code memleak?
Py_buffer *py_buffer =
reinterpret_cast<Py_buffer *>(malloc(sizeof(Py_buffer)));
py_buffer->format = strdup(dtype.c_str());
py_buffer->itemsize = sizeof(CUR_TYPE);
py_buffer->ndim = framework::arity(dst_tensor.dims());
py_buffer->len = tensor.numel();
py_buffer->strides = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * strides.size()));
for (size_t i = 0; i < strides.size(); ++i) {
py_buffer->strides[i] = strides[i];
}
py_buffer->shape = reinterpret_cast<Py_ssize_t *>(
malloc(sizeof(Py_ssize_t) * tensor.dims().size()));
for (size_t i = 0; i < tensor.dims().size(); ++i) {
py_buffer->shape[i] = tensor.dims()[i];
}
py_buffer->readonly = false;
py_buffer->suboffsets = nullptr;
py_buffer->obj = nullptr;
py_buffer->buf =
malloc(static_cast<size_t>(py_buffer->len * py_buffer->itemsize));
memcpy(py_buffer->buf, dst_tensor.data<CUR_TYPE>(),
static_cast<size_t>(py_buffer->len * py_buffer->itemsize));
return pybind11::buffer_info(py_buffer, true);
} else {
return pybind11::buffer_info(
dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
pybind11::format_descriptor<CUR_TYPE>::format(),
dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE), dtype,
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
}
} else {
......
......@@ -289,9 +289,9 @@ class TestFP16CUDNNWithGroup(TestWithGroup):
self.check_output_with_place(place, atol=2e-2)
# class TestCUDNNWith1x1(TestWith1x1):
# def init_kernel_type(self):
# self.use_cudnn = True
class TestCUDNNWith1x1(TestWith1x1):
def init_kernel_type(self):
self.use_cudnn = True
class TestFP16CUDNNWith1x1(TestWith1x1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册