diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt
index c7d7b14518ebb8415014a78fc1a3bafa8c386191..c95b54a4dfa61f59e3cdec842037df684dfa5e12 100644
--- a/paddle/platform/CMakeLists.txt
+++ b/paddle/platform/CMakeLists.txt
@@ -2,3 +2,6 @@ nv_test(cuda_test SRCS cuda_test.cu)
 
 cc_library(place SRCS place.cc)
 cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
+
+cc_library(dynamic_loader SRCS dynamic_loader.cc)
+nv_test(device_context_test SRCS device_context_test.cu DEPS place dynamic_loader glog gflags)
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
new file mode 100644
index 0000000000000000000000000000000000000000..f95aac4a36034899ac3a209bbfc5292ac28d3526
--- /dev/null
+++ b/paddle/platform/device_context.h
@@ -0,0 +1,202 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+
+#ifndef PADDLE_ONLY_CPU
+#include "paddle/platform/cublas.h"
+#include "paddle/platform/cuda.h"
+#include "paddle/platform/cudnn.h"
+#include "paddle/platform/curand.h"
+#define EIGEN_USE_GPU
+#endif
+
+#include "paddle/framework/enforce.h"
+#include "paddle/platform/place.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+
+namespace paddle {
+namespace platform {
+
+// Polymorphic base class shared by the per-device contexts below.
+class DeviceContext {
+ public:
+  virtual ~DeviceContext() {}
+};
+
+// CPU context: lazily creates and owns an Eigen::DefaultDevice.
+class CpuDeviceContext : public DeviceContext {
+ public:
+  // Returns a copy of the Eigen CPU device, creating it on first use.
+  Eigen::DefaultDevice eigen_device() {
+    if (!eigen_device_) {
+      eigen_device_.reset(new Eigen::DefaultDevice());
+    }
+    return *eigen_device_;
+  }
+
+ private:
+  // Owned through unique_ptr so the lazily created device is freed with the
+  // context instead of leaking.
+  std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
+};
+
+#ifndef PADDLE_ONLY_CPU
+// Scoped device switch: records the current device id, switches to
+// `new_place` if it differs, and switches back when the guard is destroyed.
+class DeviceGuard {
+ public:
+  explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
+    if (previous_ != new_place) {
+      paddle::platform::SetDeviceId(new_place.device);
+    }
+  }
+
+  // The destructor restores the device, so a copied guard would restore twice.
+  DeviceGuard(const DeviceGuard&) = delete;
+  DeviceGuard& operator=(const DeviceGuard&) = delete;
+
+  ~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); }
+
+ private:
+  GPUPlace previous_;
+};
+
+// GPU context bound to one GPUPlace. Owns a CUDA stream and the Eigen device
+// built on it; cuBLAS, cuDNN and cuRAND handles are created lazily on first
+// request and attached to that stream.
+class CudaDeviceContext : public DeviceContext {
+ public:
+  explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
+    DeviceGuard guard(gpu_place_);
+    paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
+                                     "cudaStreamCreate failed");
+    eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
+    eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
+  }
+
+  // The context owns a stream and library handles and hands out &stream_ to
+  // Eigen, so copying it would double-free and dangle.
+  CudaDeviceContext(const CudaDeviceContext&) = delete;
+  CudaDeviceContext& operator=(const CudaDeviceContext&) = delete;
+
+  // Blocks until all work queued on this context's stream has finished.
+  void Wait() {
+    paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
+                                     "cudaStreamSynchronize failed");
+  }
+
+  cudaStream_t stream() { return stream_; }
+
+  Eigen::GpuDevice eigen_device() { return *eigen_device_; }
+
+  // Returns the cuBLAS handle, creating it on this context's stream on first
+  // use.
+  cublasHandle_t cublas_handle() {
+    if (!blas_handle_) {
+      DeviceGuard guard(gpu_place_);
+      PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
+                     "cublasCreate failed");
+      PADDLE_ENFORCE(
+          cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
+          "cublasSetStream failed");
+    }
+    return blas_handle_;
+  }
+
+  // Returns the cuDNN handle, creating it on this context's stream on first
+  // use.
+  cudnnHandle_t cudnn_handle() {
+    if (!dnn_handle_) {
+      DeviceGuard guard(gpu_place_);
+      PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
+                     "cudnnCreate failed");
+      PADDLE_ENFORCE(
+          cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
+          "cudnnSetStream failed");
+    }
+    return dnn_handle_;
+  }
+
+  // Returns the cuRAND generator, creating it on first use, seeding it with
+  // random_seed_ and attaching it to this context's stream.
+  curandGenerator_t curand_generator() {
+    if (!rand_generator_) {
+      DeviceGuard guard(gpu_place_);
+      PADDLE_ENFORCE(
+          curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
+              CURAND_STATUS_SUCCESS,
+          "curandCreateGenerator failed");
+      PADDLE_ENFORCE(
+          curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) ==
+              CURAND_STATUS_SUCCESS,
+          "curandSetPseudoRandomGeneratorSeed failed");
+      PADDLE_ENFORCE(
+          curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
+          "curandSetStream failed");
+    }
+    return rand_generator_;
+  }
+
+  // NOTE(review): the checks below can raise from a destructor; confirm
+  // whether teardown failures should be logged instead of enforced.
+  ~CudaDeviceContext() {
+    // Tear down with the owning device selected, as during creation.
+    DeviceGuard guard(gpu_place_);
+    Wait();
+    if (blas_handle_) {
+      PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
+                     "cublasDestroy failed");
+    }
+
+    if (dnn_handle_) {
+      PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
+                     "cudnnDestroy failed");
+    }
+
+    if (rand_generator_) {
+      PADDLE_ENFORCE(
+          curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS,
+          "curandDestroyGenerator failed");
+    }
+
+    // Release the Eigen wrappers (device first, it points at the stream
+    // wrapper) before the stream they refer to is destroyed.
+    eigen_device_.reset();
+    eigen_stream_.reset();
+
+    paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
+                                     "cudaStreamDestroy failed");
+  }
+
+ private:
+  GPUPlace gpu_place_;
+  cudaStream_t stream_;
+
+  std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
+  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
+
+  cublasHandle_t blas_handle_{nullptr};
+
+  cudnnHandle_t dnn_handle_{nullptr};
+
+  // Seed handed to cuRAND; initialized so curand_generator() never reads an
+  // indeterminate value.
+  int random_seed_{0};
+  curandGenerator_t rand_generator_{nullptr};
+};
+#endif
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/platform/device_context_test.cu b/paddle/platform/device_context_test.cu
new file mode 100644
index 
0000000000000000000000000000000000000000..a15fb53b719b33844d70c79be3dbfbcb93ad841a
--- /dev/null
+++ b/paddle/platform/device_context_test.cu
@@ -0,0 +1,34 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/platform/device_context.h"
+#include "gtest/gtest.h"
+
+// Smoke test: for every device reported by GetDeviceCount(), build a context
+// and request the Eigen device and each library handle once.
+TEST(DeviceContext, CudaDevice) {
+  int count = paddle::platform::GetDeviceCount();
+  for (int i = 0; i < count; i++) {
+    // Automatic storage: the context is released even if an accessor throws.
+    paddle::platform::CudaDeviceContext device_context(i);
+    __attribute__((unused)) Eigen::GpuDevice gpu_device =
+        device_context.eigen_device();
+    __attribute__((unused)) cudnnHandle_t cudnn_handle =
+        device_context.cudnn_handle();
+    __attribute__((unused)) cublasHandle_t cublas_handle =
+        device_context.cublas_handle();
+    __attribute__((unused)) curandGenerator_t curand_handle =
+        device_context.curand_generator();
+  }
+}