From ef5f9debc61ce4f6b3142fedbf85a118a34731eb Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 13:51:04 +0800 Subject: [PATCH] refine device_context --- paddle/platform/CMakeLists.txt | 1 + .../{cuda_device_context.h => cuda_device.h} | 13 +++--- paddle/platform/cuda_device_test.cc | 33 +++++++++++++++ paddle/platform/device.h | 41 +++++++++++++++++++ paddle/platform/device_context.h | 23 ++++------- paddle/platform/device_context_test.cc | 23 +++++------ 6 files changed, 102 insertions(+), 32 deletions(-) rename paddle/platform/{cuda_device_context.h => cuda_device.h} (94%) create mode 100644 paddle/platform/cuda_device_test.cc create mode 100644 paddle/platform/device.h diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index e93592cc4c..d40e49b546 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,4 +5,5 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) +nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device.h similarity index 94% rename from paddle/platform/cuda_device_context.h rename to paddle/platform/cuda_device.h index 69415fe615..cbb69d1cc5 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device.h @@ -20,10 +20,12 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU -#include "paddle/platform/device_context.h" +#include "paddle/platform/device.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" +using DEVICE_GPU = Eigen::GpuDevice; + namespace paddle { namespace platform { @@ -41,9 +43,10 @@ class GPUPlaceGuard { GPUPlace previous_; }; -class CUDADeviceContext : public DeviceContext { +template <> +class Device { public: - explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); @@ -58,7 +61,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice eigen_device() { return *eigen_device_; } + DEVICE_GPU eigen_device() { return *eigen_device_; } cublasHandle_t cublas_handle() { if (!blas_handle_) { @@ -136,7 +139,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream_; Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; + DEVICE_GPU* eigen_device_; cublasHandle_t blas_handle_{nullptr}; diff --git a/paddle/platform/cuda_device_test.cc b/paddle/platform/cuda_device_test.cc new file mode 100644 index 0000000000..ea647be876 --- /dev/null +++ b/paddle/platform/cuda_device_test.cc @@ -0,0 +1,33 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/cuda_device.h" +#include "gtest/gtest.h" + +TEST(Device, Init) { + int count = paddle::platform::GetDeviceCount(); + for (int i = 0; i < count; i++) { + paddle::platform::Device* device = + new paddle::platform::Device(i); + Eigen::GpuDevice gpu_device = device->eigen_device(); + ASSERT_NE(nullptr, gpu_device.stream()); + cudnnHandle_t cudnn_handle = device->cudnn_handle(); + ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); + curandGenerator_t curand_handle = device->curand_generator(); + ASSERT_NE(nullptr, curand_handle); + delete device; + } +} diff --git a/paddle/platform/device.h b/paddle/platform/device.h new file mode 100644 index 0000000000..9ae41cbcb0 --- /dev/null +++ b/paddle/platform/device.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "unsupported/Eigen/CXX11/Tensor" + +using DEVICE_CPU = Eigen::DefaultDevice; + +namespace paddle { +namespace platform { + +template +class Device; + +template <> +class Device { + public: + DEVICE_CPU eigen_handle() { + if (!eigen_handle_) { + eigen_handle_ = new Eigen::DefaultDevice(); + } + return *eigen_handle_; + } + + private: + DEVICE_CPU* eigen_handle_{nullptr}; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f30c147126..8b0bac6280 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -14,27 +14,22 @@ limitations under the License. */ #pragma once #include "paddle/framework/enforce.h" +#include "paddle/platform/device.h" #include "unsupported/Eigen/CXX11/Tensor" +#ifndef PADDLE_ONLY_CPU +#include "paddle/platform/cuda_device.h" +#endif namespace paddle { namespace platform { -class DeviceContext { - public: - virtual ~DeviceContext() {} -}; +struct DeviceContext { + void* device_context{nullptr}; -class CPUDeviceContext : public DeviceContext { - public: - Eigen::DefaultDevice eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); - } - return *eigen_handle_; + template + inline paddle::platform::Device* device_context() { + return static_cast*>(device_context); } - - private: - Eigen::DefaultDevice* eigen_handle_{nullptr}; }; } // namespace platform diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index cc81e9e789..ab8a6d8195 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -12,22 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/platform/device_context.h" #include "gtest/gtest.h" -#include "paddle/platform/cuda_device_context.h" -TEST(CUDADeviceContext, Init) { +TEST(DeviceContext, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { - paddle::platform::CUDADeviceContext* device_context = - new paddle::platform::CUDADeviceContext(i); - Eigen::GpuDevice gpu_device = device_context->eigen_device(); + paddle::platform::Device* device = + new paddle::platform::Device(i); + paddle::platform::DeviceContext context; + context.device_context = device; + Eigen::GpuDevice gpu_device = + context.device_context->eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); - cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); - ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device_context->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); - delete device_context; + delete device; } -} +} \ No newline at end of file -- GitLab