diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index e33070c40fbfa7f2794426247ef77b8fcaee4ec6..7852bb53a9035f71f52a51529c8e3cea22b0d4aa 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -274,7 +274,7 @@ void set_constant_with_place( } template <> -void set_constant_with_place( +void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { set_constant_with_place(context, tensor, value); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 1c72b5055971e73c7aa560a61ca9d3c48dc56fbc..8cdc5f43403b0c54d3f1f01a3e97405fd5b2f434 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -125,21 +125,21 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } -CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place) +CUDNNDeviceContext::CUDNNDeviceContext(CUDNNPlace place) : CUDADeviceContext(place), place_(place) { PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream())); } -CudnnDeviceContext::~CudnnDeviceContext() { +CUDNNDeviceContext::~CUDNNDeviceContext() { SetDeviceId(place_.device); Wait(); PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } -Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); } +Place CUDNNDeviceContext::GetPlace() const { return CUDNNPlace(); } -cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; } +cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; } #endif diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f67194993db1f4160bd6894b2c845a82f4da2354..56813a1d5b3c2a7f4ff7b4eddc6fa47ed861700c 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -86,10 +86,10 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; -class CudnnDeviceContext : public CUDADeviceContext { +class CUDNNDeviceContext : public CUDADeviceContext { public: - explicit CudnnDeviceContext(CudnnPlace place); - virtual ~CudnnDeviceContext(); + explicit CUDNNDeviceContext(CUDNNPlace place); + virtual ~CUDNNDeviceContext(); /*! \brief Return place in the device context. */ Place GetPlace() const final; @@ -99,7 +99,7 @@ class CudnnDeviceContext : public CUDADeviceContext { private: cudnnHandle_t cudnn_handle_; - CudnnPlace place_; + CUDNNPlace place_; }; #endif diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index be3b2af5af09cb18f5156412ff60a7fc15a16487..109c13a8812dffac10d202cbc9d85c4e601bf197 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) { } } -TEST(Device, CudnnDeviceContext) { - using paddle::platform::CudnnDeviceContext; - using paddle::platform::CudnnPlace; +TEST(Device, CUDNNDeviceContext) { + using paddle::platform::CUDNNDeviceContext; + using paddle::platform::CUDNNPlace; if (paddle::platform::dynload::HasCUDNN()) { int count = paddle::platform::GetCUDADeviceCount(); for (int i = 0; i < count; ++i) { - CudnnDeviceContext* device_context = - new CudnnDeviceContext(CudnnPlace(i)); + CUDNNDeviceContext* device_context = + new CUDNNDeviceContext(CUDNNPlace(i)); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, device_context->stream()); diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 4526945792b2ea96cc4e9df11d8f35897cba7526..ca98920d414bc87ce243995a42e5672d0e61e108 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -51,9 +51,9 @@ struct GPUPlace { int device; }; -struct CudnnPlace : public GPUPlace { - CudnnPlace() : GPUPlace() {} - explicit CudnnPlace(int d) : GPUPlace(d) {} +struct CUDNNPlace : public GPUPlace { + CUDNNPlace() : GPUPlace() {} + explicit CUDNNPlace(int d) : GPUPlace(d) {} }; struct IsGPUPlace : public boost::static_visitor { @@ -72,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor { // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 -typedef boost::variant Place; +typedef boost::variant Place; // static check number of place types is less equal than // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) diff --git a/paddle/platform/place_test.cc b/paddle/platform/place_test.cc index 184af12c230f1ccd7826e507f16f4e91ca380a45..c536b59ed8f71bd078bd09c5bd5afeab74c71b28 100644 --- a/paddle/platform/place_test.cc +++ b/paddle/platform/place_test.cc @@ -5,16 +5,22 @@ TEST(Place, Equality) { paddle::platform::CPUPlace cpu; paddle::platform::GPUPlace g0(0), g1(1), gg0(0); + paddle::platform::CUDNNPlace d0(0), d1(1), dd0(0); EXPECT_EQ(cpu, cpu); EXPECT_EQ(g0, g0); EXPECT_EQ(g1, g1); EXPECT_EQ(g0, gg0); + EXPECT_EQ(d0, dd0); EXPECT_NE(g0, g1); + EXPECT_NE(d0, d1); EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0)); EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu)); + + EXPECT_TRUE(paddle::platform::is_gpu_place(d0)); + EXPECT_FALSE(paddle::platform::places_are_same_class(g0, d0)); } TEST(Place, Default) {