diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 1b560a7e2d29c1b63a25d4ec9bbd82d5960a279d..e33070c40fbfa7f2794426247ef77b8fcaee4ec6 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -273,6 +273,18 @@ void set_constant_with_place(
                            TensorSetConstantGPU(context, tensor, value));
 }
 
+// Forwards to set_constant_with_place with the same context, tensor and value.
+// NOTE(review): no explicit template arguments are visible on either the
+// specialization or the call below; presumably this is the CudnnPlace
+// specialization delegating to the GPUPlace one -- confirm, because as
+// written the call would resolve back to this function.
+template <>
+void set_constant_with_place(
+    const platform::DeviceContext& context, framework::Tensor* tensor,
+    float value) {
+  set_constant_with_place(context, tensor, value);
+}
+
 template struct RowwiseAdd;
 template struct RowwiseAdd;
 template struct ColwiseSum;
diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc
index 2c7f96421621b9a34d1ec96c13d9c354a0d4012c..1c72b5055971e73c7aa560a61ca9d3c48dc56fbc 100644
--- a/paddle/platform/device_context.cc
+++ b/paddle/platform/device_context.cc
@@ -125,6 +125,29 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
+// Builds the base CUDADeviceContext for `place`, then creates a cuDNN handle
+// and sets this context's stream on it.
+CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
+    : CUDADeviceContext(place), place_(place) {
+  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
+}
+
+// Calls SetDeviceId with this context's device and Wait() before destroying
+// the cuDNN handle created in the constructor.
+CudnnDeviceContext::~CudnnDeviceContext() {
+  SetDeviceId(place_.device);
+  Wait();
+  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+}
+
+// Returns the place this context was constructed with, so the device id is
+// preserved (a default-constructed CudnnPlace would not carry it).
+Place CudnnDeviceContext::GetPlace() const { return place_; }
+
+// Returns the cuDNN handle created by this context's constructor.
+cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
+
 #endif
 
 }  // namespace platform
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index 596d9d0bba420a47fc10cc9dd96a755daa35dbac..f67194993db1f4160bd6894b2c845a82f4da2354 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
   cublasHandle_t cublas_handle_;
 };
+class CudnnDeviceContext : public CUDADeviceContext {
+ public:
+  explicit CudnnDeviceContext(CudnnPlace place);
+  virtual ~CudnnDeviceContext();
+
+  /*! \brief Return the place associated with this device context. */
+  Place GetPlace() const final;
+
+  /*! \brief Return the cudnn handle held by this device context. */
+  cudnnHandle_t cudnn_handle() const;
+
+ private:
+  cudnnHandle_t cudnn_handle_;
+  CudnnPlace place_;
+};
+
 #endif
 
 }  // namespace platform
diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc
index 4893cd92f6a74f7992c279ebd51232049f29e853..be3b2af5af09cb18f5156412ff60a7fc15a16487 100644
--- a/paddle/platform/device_context_test.cc
+++ b/paddle/platform/device_context_test.cc
@@ -46,3 +46,21 @@ TEST(Device, CUDADeviceContext) {
     delete device_context;
   }
 }
+
+// Checks that a CudnnDeviceContext can be created for each device index below
+// GetCUDADeviceCount() and exposes a non-null cuDNN handle and stream. Does
+// nothing when HasCUDNN() reports false.
+TEST(Device, CudnnDeviceContext) {
+  using paddle::platform::CudnnDeviceContext;
+  using paddle::platform::CudnnPlace;
+  if (!paddle::platform::dynload::HasCUDNN()) return;
+
+  const int count = paddle::platform::GetCUDADeviceCount();
+  for (int i = 0; i < count; ++i) {
+    // Automatic storage: ASSERT_* returns from the test body on failure, which
+    // would skip a trailing `delete` and leak the context and its handles.
+    CudnnDeviceContext device_context{CudnnPlace(i)};
+    ASSERT_NE(nullptr, device_context.cudnn_handle());
+    ASSERT_NE(nullptr, device_context.stream());
+  }
+}
diff --git a/paddle/platform/place.h b/paddle/platform/place.h
index 5370360a7de26e409a1545182a12d3df1f37658b..f0dcec8f523fb22c2dd046113b6a8f8a0b6d916d 100644
--- a/paddle/platform/place.h
+++ b/paddle/platform/place.h
@@ -43,6 +43,13 @@ struct GPUPlace {
   int device;
 };
 
+// Place type accepted by CudnnDeviceContext. It adds no state of its own: both
+// constructors simply forward to the corresponding GPUPlace constructor.
+struct CudnnPlace : public GPUPlace {
+  CudnnPlace() : GPUPlace() {}
+  explicit CudnnPlace(int d) : GPUPlace(d) {}
+};
+
 struct IsGPUPlace : public boost::static_visitor {
   bool operator()(const CPUPlace &) const { return false; }
   bool operator()(const GPUPlace &gpu) const { return true; }
@@ -52,7 +59,7 @@ struct IsGPUPlace : public boost::static_visitor {
 // should be less equal than
2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 -typedef boost::variant Place; +typedef boost::variant Place; // static check number of place types is less equal than // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)