提交 aa5ca8a9 编写于 作者: Q qijun

Fix build error in the pybind bindings

上级 d5109130
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
......@@ -131,18 +132,24 @@ All parameter, weight, gradient are variables in Paddle.
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static(
"create",
[](paddle::platform::Place) -> paddle::platform::DeviceContext* {
if (paddle::platform::is_gpu_place(place)) {
return new paddle::platform::GPUDeviceContext(place);
} else if (paddle::platform::is_cpu_place(place)) {
.def_static("cpu_context",
[]() -> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
}
})
.def_static("gpu_context",
[](paddle::platform::Place& place)
-> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU
// PADDLE_THROW("'GPUPlace' is not supported in CPU only
// device.");
return nullptr;
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
});
py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>());
.def(py::init<>());
py::class_<paddle::platform::Place>(m, "CPUPlace").def(py::init<>());
......
......@@ -13,6 +13,7 @@
limitations under the License. */
#pragma once
#include <string>
#include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h"
......@@ -57,9 +58,9 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1];
}
Tensor dst_tensor;
framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom(tensor, platform::CPUPlace());
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor = tensor;
}
......@@ -96,20 +97,13 @@ void PyTensorSetFromArray(
auto *dst = self.mutable_data<T>(self.place());
if (paddle::platform::is_cpu_place(self.place())) {
paddle::memory::Copy<paddle::platform::CPUPlace,
paddle::platform::CPUPlace>(
place, dst, place, array.data(), sizeof(T) * array.size());
} else if (paddle::platform::is_gpu_place(place)) {
std::memcpy(dst, array.data(), sizeof(T) * array.size());
} else if (paddle::platform::is_gpu_place(self.place())) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
paddle::memory::Copy<paddle::platform::GPUPlace,
paddle::platform::CPUPlace>(
place,
dst,
paddle::platform::CPUPlace(),
array.data(),
sizeof(T) * array.size());
GpuMemcpySync(
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
#endif
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册