From 8ee50a35d408634c817d3da849a15217e57dcba1 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 07:50:08 +0000 Subject: [PATCH] fix gpu build error --- paddle/platform/CMakeLists.txt | 1 - paddle/platform/cuda_device_context.h | 15 +++++++-------- paddle/platform/device_context.h | 24 +++++++++++------------- paddle/platform/device_context_test.cc | 5 +++-- 4 files changed, 21 insertions(+), 24 deletions(-) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index d40e49b546e..e93592cc4cb 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,5 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h index c38dcd5a615..8a9d15e8a84 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device_context.h @@ -20,19 +20,13 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU +#include "paddle/platform/device_context.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { -class CUDADeviceContext; - -template <> -Eigen::GpuDevice DeviceContext::get_eigen_device() { - return static_cast(this)->eigen_handle(); -} - class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { @@ -49,7 +43,7 @@ class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext { public: - explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); @@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext { int random_seed_; curandGenerator_t rand_generator_{nullptr}; }; + +template <> +Eigen::GpuDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index d2a51699917..d2f7cf62166 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,30 +20,23 @@ limitations under the License. */ namespace paddle { namespace platform { -class CPUDeviceContext; - class DeviceContext { public: virtual ~DeviceContext() {} template - DeviceType get_eigen_device(); + inline DeviceType get_eigen_device(); virtual Place GetPlace() const = 0; }; -template <> -Eigen::DefaultDevice DeviceContext::get_eigen_device() { - return static_cast(this)->eigen_handle(); -} - class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); + Eigen::DefaultDevice eigen_device() { + if (!eigen_device_) { + eigen_device_ = new Eigen::DefaultDevice(); } - return *eigen_handle_; + return *eigen_device_; } Place GetPlace() const override { @@ -52,7 +45,12 @@ class CPUDeviceContext : public DeviceContext { } private: - Eigen::DefaultDevice* eigen_handle_{nullptr}; + Eigen::DefaultDevice* eigen_device_{nullptr}; }; + +template <> +Eigen::DefaultDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8390e97b15f..abaaaececf5 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" -#include "paddle/platform/cuda_device.h" +#include "paddle/platform/cuda_device_context.h" +using DEVICE_GPU = Eigen::GpuDevice; TEST(Device, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { paddle::platform::DeviceContext* device_context = new paddle::platform::CUDADeviceContext(i); Eigen::GpuDevice gpu_device = - device_context->get_eigen_device(); + device_context->template get_eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); delete device_context; } -- GitLab