Commit d5109130 authored by qijun

set default cpu place for tensor alloc

Parent e2ba1337
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <memory>
 #include <typeindex>
 #include "paddle/framework/ddim.h"
+#include "paddle/memory/memcpy.h"
 #include "paddle/memory/memory.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
@@ -104,15 +105,21 @@ class Tensor {
   template <typename T>
   void CopyFrom(const Tensor& src, platform::Place dst_place) {
-    PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
-                       platform::is_cpu_place(dst_place),
-                   "Tensor::CopyFrom only support CPU now.");
-    src.EnforceSufficientMemory<T>();
+    PADDLE_ENFORCE(platform::is_cpu_place(dst_place),
+                   "Tensor::CopyFrom only support dst CPU now.");
     size_t size = product(src.dims_) * sizeof(T);
     Resize(src.dims());
     const void* src_ptr = static_cast<const void*>(src.data<T>());
     void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
-    memcpy(dst_ptr, src_ptr, size);
+    if (paddle::platform::is_cpu_place(src.holder_->place())) {
+      std::memcpy(dst_ptr, src_ptr, size);
+    } else if (paddle::platform::is_gpu_place(src.holder_->place())) {
+#ifdef PADDLE_ONLY_CPU
+      PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+#else
+      GpuMemcpySync(dst_ptr, src_ptr, size, cudaMemcpyDeviceToHost);
+#endif
+    }
   }
   template <typename T>
...
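Note on the hunk above: CopyFrom still requires a CPU destination, but it now dispatches on the source tensor's place, routing a GPU-resident source through GpuMemcpySync (cudaMemcpyDeviceToHost) instead of a plain std::memcpy. A minimal usage sketch, assuming a CUDA-enabled build of this tree (PADDLE_ONLY_CPU undefined); the helper name ToHost is illustrative, not part of the codebase:

    #include "paddle/framework/tensor.h"
    #include "paddle/platform/place.h"

    // Hypothetical helper: stage a (possibly GPU-resident) tensor on the host.
    paddle::framework::Tensor ToHost(const paddle::framework::Tensor& src) {
      paddle::framework::Tensor dst;
      // The destination place must be CPU. CopyFrom resizes dst to src.dims()
      // and picks std::memcpy or GpuMemcpySync based on where src lives.
      dst.CopyFrom<float>(src, paddle::platform::CPUPlace());
      return dst;
    }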
@@ -66,10 +66,18 @@ PYBIND11_PLUGIN(core) {
            [](pd::Tensor& self, paddle::platform::Place& place) {
              self.mutable_data<float>(place);
            })
+      .def("alloc_float",
+           [](pd::Tensor& self) {
+             self.mutable_data<float>(paddle::platform::CPUPlace());
+           })
       .def("alloc_int",
            [](pd::Tensor& self, paddle::platform::Place& place) {
              self.mutable_data<int>(place);
            })
+      .def("alloc_int",
+           [](pd::Tensor& self) {
+             self.mutable_data<int>(paddle::platform::CPUPlace());
+           })
       .def("set", paddle::pybind::PyTensorSetFromArray<float>)
       .def("set", paddle::pybind::PyTensorSetFromArray<int>)
       .def("shape",
...
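The two new zero-argument .def overloads are what the commit title refers to: Python callers can now omit the place argument and get a default CPU allocation. A standalone sketch of the same pybind11 pattern (demo names only, not PaddlePaddle code):

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    struct Place {};                        // stand-in for platform::CPUPlace
    struct Tensor { void alloc(Place) {} };

    PYBIND11_PLUGIN(demo) {
      py::module m("demo");
      py::class_<Place>(m, "Place").def(py::init<>());
      py::class_<Tensor>(m, "Tensor")
          .def(py::init<>())
          // The same method name bound twice builds an overload set; pybind11
          // tries each signature in order, so t.alloc_float() with no argument
          // falls through to the default-CPU lambda.
          .def("alloc_float", [](Tensor& self, Place& p) { self.alloc(p); })
          .def("alloc_float", [](Tensor& self) { self.alloc(Place()); });
      return m.ptr();
    }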
@@ -57,11 +57,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
         strides[i - 1] = sizeof(CUR_TYPE) * prod;
         prod *= dims_outside[i - 1];
       }
+      Tensor dst_tensor;
+      if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
+        dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
+      } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
+        dst_tensor = tensor;
+      }
       return py::buffer_info(
-          tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
+          dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
           sizeof(CUR_TYPE),
           py::format_descriptor<CUR_TYPE>::format(),
-          (size_t)framework::arity(tensor.dims()),
+          (size_t)framework::arity(dst_tensor.dims()),
           dims_outside,
           strides);
     } else {
...
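The staging through dst_tensor exists because py::buffer_info hands NumPy a raw pointer, and that pointer must be host memory: a GPU tensor is first copied to CPU via the new CopyFrom path, while a CPU tensor is aliased as-is. A standalone sketch of the same buffer_info construction over host data (demo names only, not PaddlePaddle code):

    #include <pybind11/pybind11.h>
    #include <vector>
    namespace py = pybind11;

    struct HostArray { std::vector<float> data{1.f, 2.f, 3.f}; };

    PYBIND11_PLUGIN(bufdemo) {
      py::module m("bufdemo");
      py::class_<HostArray>(m, "HostArray", py::buffer_protocol())
          .def(py::init<>())
          .def_buffer([](HostArray& a) {
            return py::buffer_info(a.data.data(),  // must be a host pointer
                                   sizeof(float),
                                   py::format_descriptor<float>::format(),
                                   1,                // rank
                                   {a.data.size()},  // shape
                                   {sizeof(float)}); // strides in bytes
          });
      return m.ptr();
    }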