提交 8ee50a35 编写于 作者: Q qijun

fix gpu build error

上级 ca23d861
......@@ -5,5 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place)
nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place)
......@@ -20,19 +20,13 @@ limitations under the License. */
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#define EIGEN_USE_GPU
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
class CUDADeviceContext;
template <>
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return static_cast<CUDADeviceContext*>(this)->eigen_handle();
}
class GPUPlaceGuard {
public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
......@@ -49,7 +43,7 @@ class GPUPlaceGuard {
class CUDADeviceContext : public DeviceContext {
public:
explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed");
......@@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext {
int random_seed_;
curandGenerator_t rand_generator_{nullptr};
};
template <>
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
return dynamic_cast<CUDADeviceContext*>(this)->eigen_device();
}
} // namespace platform
} // namespace paddle
......@@ -20,30 +20,23 @@ limitations under the License. */
namespace paddle {
namespace platform {
class CPUDeviceContext;
class DeviceContext {
public:
virtual ~DeviceContext() {}
template <typename DeviceType>
DeviceType get_eigen_device();
inline DeviceType get_eigen_device();
virtual Place GetPlace() const = 0;
};
template <>
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return static_cast<CPUDeviceContext*>(this)->eigen_handle();
}
class CPUDeviceContext : public DeviceContext {
public:
Eigen::DefaultDevice eigen_handle() {
if (!eigen_handle_) {
eigen_handle_ = new Eigen::DefaultDevice();
Eigen::DefaultDevice eigen_device() {
if (!eigen_device_) {
eigen_device_ = new Eigen::DefaultDevice();
}
return *eigen_handle_;
return *eigen_device_;
}
Place GetPlace() const override {
......@@ -52,7 +45,12 @@ class CPUDeviceContext : public DeviceContext {
}
private:
Eigen::DefaultDevice* eigen_handle_{nullptr};
Eigen::DefaultDevice* eigen_device_{nullptr};
};
template <>
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
return dynamic_cast<CPUDeviceContext*>(this)->eigen_device();
}
} // namespace platform
} // namespace paddle
......@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/platform/cuda_device.h"
#include "paddle/platform/cuda_device_context.h"
using DEVICE_GPU = Eigen::GpuDevice;
TEST(Device, Init) {
int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) {
paddle::platform::DeviceContext* device_context =
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice gpu_device =
device_context->get_eigen_device<DEVICE_GPU>();
device_context->template get_eigen_device<DEVICE_GPU>();
ASSERT_NE(nullptr, gpu_device.stream());
delete device_context;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册