Commit d5109130 authored by: Q qijun

set default cpu place for tensor alloc

Parent e2ba1337
@@ -19,6 +19,7 @@ limitations under the License. */
#include <memory>
#include <typeindex>
#include "paddle/framework/ddim.h"
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
@@ -104,15 +105,21 @@ class Tensor {
template <typename T>
void CopyFrom(const Tensor& src, platform::Place dst_place) {
PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) &&
platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support CPU now.");
src.EnforceSufficientMemory<T>();
PADDLE_ENFORCE(platform::is_cpu_place(dst_place),
"Tensor::CopyFrom only support dst CPU now.");
size_t size = product(src.dims_) * sizeof(T);
Resize(src.dims());
const void* src_ptr = static_cast<const void*>(src.data<T>());
void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
memcpy(dst_ptr, src_ptr, size);
if (paddle::platform::is_cpu_place(src.holder_->place())) {
std::memcpy(dst_ptr, src_ptr, size);
} else if (paddle::platform::is_gpu_place(src.holder_->place())) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
GpuMemcpySync(dst_ptr, src_ptr, size, cudaMemcpyDeviceToHost);
#endif
}
}
template <typename T>
......
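The rewritten CopyFrom dispatches on where the source tensor lives: host memory is copied with std::memcpy, device memory goes through GpuMemcpySync, and a CPU-only build rejects GPU sources outright. Below is a minimal, self-contained sketch of that dispatch pattern; CopyToHost is a hypothetical helper invented for illustration, and the direct cudaMemcpy call is only an assumption about what GpuMemcpySync wraps.

```cpp
#include <cstddef>
#include <cstring>
#include <stdexcept>

#ifndef PADDLE_ONLY_CPU
#include <cuda_runtime.h>
#endif

// Hypothetical helper mirroring the dispatch in Tensor::CopyFrom above:
// copy `size` bytes into a host buffer, using std::memcpy for host sources
// and a synchronous device-to-host copy for GPU sources.
void CopyToHost(void* dst, const void* src, size_t size, bool src_on_gpu) {
  if (!src_on_gpu) {
    std::memcpy(dst, src, size);  // host -> host
  } else {
#ifdef PADDLE_ONLY_CPU
    throw std::runtime_error("GPU source is not supported in a CPU-only build.");
#else
    // Assumed to be roughly what GpuMemcpySync does under the hood.
    cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
#endif
  }
}
```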
@@ -66,10 +66,18 @@ PYBIND11_PLUGIN(core) {
[](pd::Tensor& self, paddle::platform::Place& place) {
self.mutable_data<float>(place);
})
.def("alloc_float",
[](pd::Tensor& self) {
self.mutable_data<float>(paddle::platform::CPUPlace());
})
.def("alloc_int",
[](pd::Tensor& self, paddle::platform::Place& place) {
self.mutable_data<int>(place);
})
.def("alloc_int",
[](pd::Tensor& self) {
self.mutable_data<int>(paddle::platform::CPUPlace());
})
.def("set", paddle::pybind::PyTensorSetFromArray<float>)
.def("set", paddle::pybind::PyTensorSetFromArray<int>)
.def("shape",
......
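The Python-facing part of the change is the pair of zero-argument alloc_float / alloc_int overloads that fall back to paddle::platform::CPUPlace(), which is what the commit title refers to. pybind11 resolves same-named bindings by trying them in registration order, so tensor.alloc_float(place) still hits the original binding while tensor.alloc_float() lands on the new CPU default. A toy sketch of that overload pattern follows; ToyTensor and toycore are invented names, and it uses the newer PYBIND11_MODULE macro rather than the PYBIND11_PLUGIN macro seen in the diff.

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Toy stand-in for the Tensor binding; not part of Paddle.
struct ToyTensor {
  bool on_cpu = false;
  void AllocFloat(bool cpu) { on_cpu = cpu; }  // pretend to allocate storage
};

PYBIND11_MODULE(toycore, m) {
  py::class_<ToyTensor>(m, "Tensor")
      .def(py::init<>())
      // Explicit-place overload, mirroring alloc_float(place) in the diff.
      .def("alloc_float",
           [](ToyTensor& self, bool cpu) { self.AllocFloat(cpu); })
      // New zero-argument overload: defaults to the CPU place.
      .def("alloc_float", [](ToyTensor& self) { self.AllocFloat(true); });
}
```

From Python, t.alloc_float(False) would select the first binding and t.alloc_float() the second, since pybind11 tries same-named overloads in the order they were registered.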
@@ -57,11 +57,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1];
}
Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
dst_tensor = tensor;
}
return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()),
(size_t)framework::arity(dst_tensor.dims()),
dims_outside,
strides);
} else {
......
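CastToPyBufferImpl now copies a GPU-resident tensor into a CPU-side dst_tensor before constructing the py::buffer_info, since the buffer protocol hands Python a raw host pointer plus element size, format, rank, shape, and strides. A minimal illustration of those buffer_info fields, using an invented HostBuffer type that holds plain host memory (nothing here is Paddle API):

```cpp
#include <pybind11/pybind11.h>

#include <vector>

namespace py = pybind11;

// Invented example type: the pointer handed to py::buffer_info must be
// ordinary host memory, which is why the diff copies GPU tensors to a
// CPU-side dst_tensor before exposing them.
struct HostBuffer {
  std::vector<float> data{1.f, 2.f, 3.f};
};

PYBIND11_MODULE(toybuf, m) {
  py::class_<HostBuffer>(m, "HostBuffer", py::buffer_protocol())
      .def(py::init<>())
      .def_buffer([](HostBuffer& b) -> py::buffer_info {
        return py::buffer_info(
            b.data.data(),                           // host data pointer
            sizeof(float),                           // size of one element
            py::format_descriptor<float>::format(),  // "f"
            1,                                       // number of dimensions
            {b.data.size()},                         // shape
            {sizeof(float)});                        // strides, in bytes
      });
}
```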