diff --git a/paddle/platform/cuda_device.h b/paddle/platform/cuda_device_context.h similarity index 94% rename from paddle/platform/cuda_device.h rename to paddle/platform/cuda_device_context.h index cbb69d1cc556696275a68e76d26d234221ab0212..420159fb2c610f29f60e3b0a11b61e47c13055dc 100644 --- a/paddle/platform/cuda_device.h +++ b/paddle/platform/cuda_device_context.h @@ -20,7 +20,6 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU -#include "paddle/platform/device.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -29,6 +28,13 @@ using DEVICE_GPU = Eigen::GpuDevice; namespace paddle { namespace platform { +class CUDADeviceContext; + +template <> +DEVICE_GPU DeviceContext::get_eigen_device() { + return static_cast(this)->eigen_handle(); +} + class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { @@ -43,8 +49,7 @@ class GPUPlaceGuard { GPUPlace previous_; }; -template <> -class Device { +class CUDADeviceContext : public DeviceContext { public: explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); @@ -61,7 +66,7 @@ class Device { cudaStream_t stream() { return stream_; } - DEVICE_GPU eigen_device() { return *eigen_device_; } + Eigen::GpuDevice eigen_device() { return *eigen_device_; } cublasHandle_t cublas_handle() { if (!blas_handle_) { @@ -139,7 +144,7 @@ class Device { cudaStream_t stream_; Eigen::CudaStreamDevice* eigen_stream_; - DEVICE_GPU* eigen_device_; + Eigen::GpuDevice* eigen_device_; cublasHandle_t blas_handle_{nullptr}; diff --git a/paddle/platform/cuda_device_test.cc b/paddle/platform/cuda_device_test.cc deleted file mode 100644 index ea647be8760c8f22a3beaa4c9b6c70dde84328f3..0000000000000000000000000000000000000000 --- a/paddle/platform/cuda_device_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/platform/cuda_device.h" -#include "gtest/gtest.h" - -TEST(Device, Init) { - int count = paddle::platform::GetDeviceCount(); - for (int i = 0; i < count; i++) { - paddle::platform::Device* device = - new paddle::platform::Device(i); - Eigen::GpuDevice gpu_device = device->eigen_device(); - ASSERT_NE(nullptr, gpu_device.stream()); - cudnnHandle_t cudnn_handle = device->cudnn_handle(); - ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device->curand_generator(); - ASSERT_NE(nullptr, curand_handle); - delete device; - } -} diff --git a/paddle/platform/device.h b/paddle/platform/device.h deleted file mode 100644 index 9ae41cbcb06782cbdecaf9a0f1aba62b69cb54f1..0000000000000000000000000000000000000000 --- a/paddle/platform/device.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "unsupported/Eigen/CXX11/Tensor" - -using DEVICE_CPU = Eigen::DefaultDevice; - -namespace paddle { -namespace platform { - -template -class Device; - -template <> -class Device { - public: - DEVICE_CPU eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); - } - return *eigen_handle_; - } - - private: - DEVICE_CPU* eigen_handle_{nullptr}; -}; - -} // namespace platform -} // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 8b0bac62802ea1c79e59c1766586e2f1cd0897b4..11a05702cd9876c98e730ccaaede2fb04696254e 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -13,23 +13,39 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/enforce.h" -#include "paddle/platform/device.h" #include "unsupported/Eigen/CXX11/Tensor" -#ifndef PADDLE_ONLY_CPU -#include "paddle/platform/cuda_device.h" -#endif + +using DEVICE_CPU = Eigen::DefaultDevice; namespace paddle { namespace platform { -struct DeviceContext { - void* device_context{nullptr}; +class CPUDeviceContext; + +class DeviceContext { + public: + virtual ~DeviceContext() {} template - inline paddle::platform::Device* device_context() { - return static_cast*>(device_context); + DeviceType get_eigen_device(); +}; + +template <> +DEVICE_CPU DeviceContext::get_eigen_device() { + return static_cast(this)->eigen_handle(); +} + +class CPUDeviceContext : public DeviceContext { + public: + Eigen::DefaultDevice eigen_handle() { + if (!eigen_handle_) { + eigen_handle_ = new Eigen::DefaultDevice(); + } + return *eigen_handle_; } + + private: + Eigen::DefaultDevice* eigen_handle_{nullptr}; }; } // namespace platform diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index ab8a6d819593c3a731dcb716a81b1bd419e49e45..8390e97b15f174bf7e101d73ce97e575f08a1ac3 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -12,19 +12,34 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/platform/device_context.h" #include "gtest/gtest.h" +#include "paddle/platform/cuda_device.h" -TEST(DeviceContext, Init) { +TEST(Device, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { - paddle::platform::Device* device = - new paddle::platform::Device(i); - paddle::platform::DeviceContext context; - context.device_context = device; + paddle::platform::DeviceContext* device_context = + new paddle::platform::CUDADeviceContext(i); Eigen::GpuDevice gpu_device = - context.device_context->eigen_device(); + device_context->get_eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); - delete device; + delete device_context; } -} \ No newline at end of file +} + +TEST(Device, CUDADeviceContext) { + int count = paddle::platform::GetDeviceCount(); + for (int i = 0; i < count; i++) { + paddle::platform::CUDADeviceContext* device_context = + new paddle::platform::CUDADeviceContext(i); + Eigen::GpuDevice gpu_device = device_context->eigen_device(); + ASSERT_NE(nullptr, gpu_device.stream()); + cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); + ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device_context->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); + curandGenerator_t curand_handle = device_context->curand_generator(); + ASSERT_NE(nullptr, curand_handle); + delete device_context; + } +}