未验证 提交 93a2d9c5 编写于 作者: Q QI JUN 提交者: GitHub

add more place test and rename Cudnn to CUDNN (#6621)

* add more place_test and rename Cudnn to CUDNN

* fix ci
上级 77cf7d4f
...@@ -274,7 +274,7 @@ void set_constant_with_place<platform::GPUPlace>( ...@@ -274,7 +274,7 @@ void set_constant_with_place<platform::GPUPlace>(
} }
template <> template <>
void set_constant_with_place<platform::CudnnPlace>( void set_constant_with_place<platform::CUDNNPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor, const platform::DeviceContext& context, framework::Tensor* tensor,
float value) { float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value); set_constant_with_place<platform::GPUPlace>(context, tensor, value);
......
...@@ -125,21 +125,21 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } ...@@ -125,21 +125,21 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place) CUDNNDeviceContext::CUDNNDeviceContext(CUDNNPlace place)
: CUDADeviceContext(place), place_(place) { : CUDADeviceContext(place), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream())); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
} }
CudnnDeviceContext::~CudnnDeviceContext() { CUDNNDeviceContext::~CUDNNDeviceContext() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
Wait(); Wait();
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
} }
Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); } Place CUDNNDeviceContext::GetPlace() const { return CUDNNPlace(); }
cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; } cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
#endif #endif
......
...@@ -86,10 +86,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -86,10 +86,10 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
}; };
class CudnnDeviceContext : public CUDADeviceContext { class CUDNNDeviceContext : public CUDADeviceContext {
public: public:
explicit CudnnDeviceContext(CudnnPlace place); explicit CUDNNDeviceContext(CUDNNPlace place);
virtual ~CudnnDeviceContext(); virtual ~CUDNNDeviceContext();
/*! \brief Return place in the device context. */ /*! \brief Return place in the device context. */
Place GetPlace() const final; Place GetPlace() const final;
...@@ -99,7 +99,7 @@ class CudnnDeviceContext : public CUDADeviceContext { ...@@ -99,7 +99,7 @@ class CudnnDeviceContext : public CUDADeviceContext {
private: private:
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
CudnnPlace place_; CUDNNPlace place_;
}; };
#endif #endif
......
...@@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) { ...@@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) {
} }
} }
TEST(Device, CudnnDeviceContext) { TEST(Device, CUDNNDeviceContext) {
using paddle::platform::CudnnDeviceContext; using paddle::platform::CUDNNDeviceContext;
using paddle::platform::CudnnPlace; using paddle::platform::CUDNNPlace;
if (paddle::platform::dynload::HasCUDNN()) { if (paddle::platform::dynload::HasCUDNN()) {
int count = paddle::platform::GetCUDADeviceCount(); int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
CudnnDeviceContext* device_context = CUDNNDeviceContext* device_context =
new CudnnDeviceContext(CudnnPlace(i)); new CUDNNDeviceContext(CUDNNPlace(i));
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
ASSERT_NE(nullptr, device_context->stream()); ASSERT_NE(nullptr, device_context->stream());
......
...@@ -51,9 +51,9 @@ struct GPUPlace { ...@@ -51,9 +51,9 @@ struct GPUPlace {
int device; int device;
}; };
struct CudnnPlace : public GPUPlace { struct CUDNNPlace : public GPUPlace {
CudnnPlace() : GPUPlace() {} CUDNNPlace() : GPUPlace() {}
explicit CudnnPlace(int d) : GPUPlace(d) {} explicit CUDNNPlace(int d) : GPUPlace(d) {}
}; };
struct IsGPUPlace : public boost::static_visitor<bool> { struct IsGPUPlace : public boost::static_visitor<bool> {
...@@ -72,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> { ...@@ -72,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> {
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place; typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
// static check number of place types is less equal than // static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
......
...@@ -5,16 +5,22 @@ ...@@ -5,16 +5,22 @@
TEST(Place, Equality) { TEST(Place, Equality) {
paddle::platform::CPUPlace cpu; paddle::platform::CPUPlace cpu;
paddle::platform::GPUPlace g0(0), g1(1), gg0(0); paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
paddle::platform::CUDNNPlace d0(0), d1(1), dd0(0);
EXPECT_EQ(cpu, cpu); EXPECT_EQ(cpu, cpu);
EXPECT_EQ(g0, g0); EXPECT_EQ(g0, g0);
EXPECT_EQ(g1, g1); EXPECT_EQ(g1, g1);
EXPECT_EQ(g0, gg0); EXPECT_EQ(g0, gg0);
EXPECT_EQ(d0, dd0);
EXPECT_NE(g0, g1); EXPECT_NE(g0, g1);
EXPECT_NE(d0, d1);
EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0)); EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0));
EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu)); EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu));
EXPECT_TRUE(paddle::platform::is_gpu_place(d0));
EXPECT_FALSE(paddle::platform::places_are_same_class(g0, d0));
} }
TEST(Place, Default) { TEST(Place, Default) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册