提交 aa5ca8a9 编写于 作者: Q qijun

fix build error

上级 d5109130
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h" #include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
...@@ -131,18 +132,24 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -131,18 +132,24 @@ All parameter, weight, gradient are variables in Paddle.
.def("temp", pd::OperatorBase::TMP_VAR_NAME); .def("temp", pd::OperatorBase::TMP_VAR_NAME);
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static( .def_static("cpu_context",
"create", []() -> paddle::platform::DeviceContext* {
[](paddle::platform::Place) -> paddle::platform::DeviceContext* {
if (paddle::platform::is_gpu_place(place)) {
return new paddle::platform::GPUDeviceContext(place);
} else if (paddle::platform::is_cpu_place(place)) {
return new paddle::platform::CPUDeviceContext(); return new paddle::platform::CPUDeviceContext();
} })
.def_static("gpu_context",
[](paddle::platform::Place& place)
-> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU
// PADDLE_THROW("'GPUPlace' is not supported in CPU only
// device.");
return nullptr;
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
}); });
py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>()); py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>());
.def(py::init<>());
py::class_<paddle::platform::Place>(m, "CPUPlace").def(py::init<>()); py::class_<paddle::platform::Place>(m, "CPUPlace").def(py::init<>());
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
...@@ -57,9 +58,9 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -57,9 +58,9 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod; strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
Tensor dst_tensor; framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) { if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom(tensor, platform::CPUPlace()); dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_gpu_place(tensor.holder_->place())) { } else if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor = tensor; dst_tensor = tensor;
} }
...@@ -96,20 +97,13 @@ void PyTensorSetFromArray( ...@@ -96,20 +97,13 @@ void PyTensorSetFromArray(
auto *dst = self.mutable_data<T>(self.place()); auto *dst = self.mutable_data<T>(self.place());
if (paddle::platform::is_cpu_place(self.place())) { if (paddle::platform::is_cpu_place(self.place())) {
paddle::memory::Copy<paddle::platform::CPUPlace, std::memcpy(dst, array.data(), sizeof(T) * array.size());
paddle::platform::CPUPlace>( } else if (paddle::platform::is_gpu_place(self.place())) {
place, dst, place, array.data(), sizeof(T) * array.size());
} else if (paddle::platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU #ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else #else
paddle::memory::Copy<paddle::platform::GPUPlace, GpuMemcpySync(
paddle::platform::CPUPlace>( dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
place,
dst,
paddle::platform::CPUPlace(),
array.data(),
sizeof(T) * array.size());
#endif #endif
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册