提交 8e4e155c 编写于 作者: C chengduoZH

add PyCUDAPinnedTensorSetFromArray

上级 6af17835
......@@ -128,13 +128,21 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place)) {
} else if (platform::is_gpu_place(place) ||
platform::is_cuda_pinned_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
PADDLE_THROW(
"'CUDAPlace' or 'CUDAPinnedPlace' is not supported in CPU only "
"device.");
}
#else
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type));
if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type));
} else if (platform::is_cuda_pinned_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CUDAPinnedPlace>(
boost::get<platform::CUDAPinnedPlace>(place), size, type));
}
}
#endif
offset_ = 0;
......
......@@ -125,6 +125,12 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCUDATensorSetFromArray<int64_t>)
.def("set", PyCUDATensorSetFromArray<bool>)
.def("set", PyCUDATensorSetFromArray<uint16_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<float>)
.def("set", PyCUDAPinnedTensorSetFromArray<int>)
.def("set", PyCUDAPinnedTensorSetFromArray<double>)
.def("set", PyCUDAPinnedTensorSetFromArray<int64_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<bool>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
#endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("set_float_element", TensorSetElement<float>)
......@@ -367,8 +373,8 @@ All parameter, weight, gradient are variables in Paddle.
self = gpu_place;
})
.def("set_place", [](platform::Place &self,
const platform::CUDAPinnedPlace &gpu_place) {
self = gpu_place;
const platform::CUDAPinnedPlace &cuda_pinned_place) {
self = cuda_pinned_place;
});
py::class_<OperatorBase>(m, "Operator")
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -208,6 +210,38 @@ void PyCUDATensorSetFromArray(
sizeof(uint16_t) * array.size(),
cudaMemcpyHostToDevice, dev_ctx->stream());
}
template <typename T>
void PyCUDAPinnedTensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array,
const paddle::platform::CUDAPinnedPlace &place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size());
}
template <>
void PyCUDAPinnedTensorSetFromArray(
framework::Tensor &self,
py::array_t<uint16_t, py::array::c_style | py::array::forcecast> array,
const paddle::platform::CUDAPinnedPlace &place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i]));
}
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<platform::float16>(place);
std::memcpy(dst, array.data(), sizeof(uint16_t) * array.size());
}
#endif
} // namespace pybind
......
......@@ -31,7 +31,7 @@ import regularizer
import average
from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace
from distribute_transpiler import DistributeTranspiler
from distribute_transpiler_simple import SimpleDistributeTranspiler
from concurrency import (Go, make_channel, channel_send, channel_recv,
......@@ -57,6 +57,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'LoDTensor',
'CPUPlace',
'CUDAPlace',
'CUDAPinnedPlace',
'Tensor',
'ParamAttr',
'WeightNormParamAttr',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册