From 0e9b393b340990cf581ec9f6e5f33af74912c0b6 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 14 Dec 2017 16:21:58 +0800 Subject: [PATCH] "derived cudnnDevice context" (#6585) * "derived cudnnDevice context" * "leave remove cudnn handle from CUDADeviceContext" * "fix math function error" --- paddle/operators/math/math_function.cu | 7 +++++++ paddle/platform/device_context.cc | 16 ++++++++++++++++ paddle/platform/device_context.h | 16 ++++++++++++++++ paddle/platform/device_context_test.cc | 16 ++++++++++++++++ paddle/platform/place.h | 7 ++++++- 5 files changed, 61 insertions(+), 1 deletion(-) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 1b560a7e2d2..e33070c40fb 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -273,6 +273,13 @@ void set_constant_with_place( TensorSetConstantGPU(context, tensor, value)); } +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + set_constant_with_place(context, tensor, value); +} + template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 2c7f9642162..1c72b505597 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } +CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place) + : CUDADeviceContext(place), place_(place) { + PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream())); +} + +CudnnDeviceContext::~CudnnDeviceContext() { + SetDeviceId(place_.device); + Wait(); + PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); +} + +Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); } + +cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; } + #endif } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 596d9d0bba4..f67194993db 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; +class CudnnDeviceContext : public CUDADeviceContext { + public: + explicit CudnnDeviceContext(CudnnPlace place); + virtual ~CudnnDeviceContext(); + + /*! \brief Return place in the device context. */ + Place GetPlace() const final; + + /*! \brief Return cudnn handle in the device context. */ + cudnnHandle_t cudnn_handle() const; + + private: + cudnnHandle_t cudnn_handle_; + CudnnPlace place_; +}; + #endif } // namespace platform diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 4893cd92f6a..be3b2af5af0 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) { delete device_context; } } + +TEST(Device, CudnnDeviceContext) { + using paddle::platform::CudnnDeviceContext; + using paddle::platform::CudnnPlace; + if (paddle::platform::dynload::HasCUDNN()) { + int count = paddle::platform::GetCUDADeviceCount(); + for (int i = 0; i < count; ++i) { + CudnnDeviceContext* device_context = + new CudnnDeviceContext(CudnnPlace(i)); + cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); + ASSERT_NE(nullptr, cudnn_handle); + ASSERT_NE(nullptr, device_context->stream()); + delete device_context; + } + } +} diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 5370360a7de..f0dcec8f523 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -43,6 +43,11 @@ struct GPUPlace { int device; }; +struct CudnnPlace : public GPUPlace { + CudnnPlace() : GPUPlace() {} + explicit CudnnPlace(int d) : GPUPlace(d) {} +}; + struct IsGPUPlace : public boost::static_visitor { bool operator()(const CPUPlace &) const { return false; } bool operator()(const GPUPlace &gpu) const { return true; } @@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor { // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 -typedef boost::variant Place; +typedef boost::variant Place; // static check number of place types is less equal than // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) -- GitLab