未验证 提交 93a2d9c5 编写于 作者: Q QI JUN 提交者: GitHub

add more place test and rename Cudnn to CUDNN (#6621)

* add more place_test and rename Cudnn to CUDNN

* fix ci
上级 77cf7d4f
......@@ -274,7 +274,7 @@ void set_constant_with_place<platform::GPUPlace>(
}
template <>
void set_constant_with_place<platform::CudnnPlace>(
void set_constant_with_place<platform::CUDNNPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
......
......@@ -125,21 +125,21 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
CUDNNDeviceContext::CUDNNDeviceContext(CUDNNPlace place)
: CUDADeviceContext(place), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
}
CudnnDeviceContext::~CudnnDeviceContext() {
CUDNNDeviceContext::~CUDNNDeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); }
Place CUDNNDeviceContext::GetPlace() const { return CUDNNPlace(); }
cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
#endif
......
......@@ -86,10 +86,10 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_;
};
class CudnnDeviceContext : public CUDADeviceContext {
class CUDNNDeviceContext : public CUDADeviceContext {
public:
explicit CudnnDeviceContext(CudnnPlace place);
virtual ~CudnnDeviceContext();
explicit CUDNNDeviceContext(CUDNNPlace place);
virtual ~CUDNNDeviceContext();
/*! \brief Return place in the device context. */
Place GetPlace() const final;
......@@ -99,7 +99,7 @@ class CudnnDeviceContext : public CUDADeviceContext {
private:
cudnnHandle_t cudnn_handle_;
CudnnPlace place_;
CUDNNPlace place_;
};
#endif
......
......@@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) {
}
}
TEST(Device, CudnnDeviceContext) {
using paddle::platform::CudnnDeviceContext;
using paddle::platform::CudnnPlace;
TEST(Device, CUDNNDeviceContext) {
using paddle::platform::CUDNNDeviceContext;
using paddle::platform::CUDNNPlace;
if (paddle::platform::dynload::HasCUDNN()) {
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
CudnnDeviceContext* device_context =
new CudnnDeviceContext(CudnnPlace(i));
CUDNNDeviceContext* device_context =
new CUDNNDeviceContext(CUDNNPlace(i));
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
ASSERT_NE(nullptr, device_context->stream());
......
......@@ -51,9 +51,9 @@ struct GPUPlace {
int device;
};
struct CudnnPlace : public GPUPlace {
CudnnPlace() : GPUPlace() {}
explicit CudnnPlace(int d) : GPUPlace(d) {}
struct CUDNNPlace : public GPUPlace {
CUDNNPlace() : GPUPlace() {}
explicit CUDNNPlace(int d) : GPUPlace(d) {}
};
struct IsGPUPlace : public boost::static_visitor<bool> {
......@@ -72,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> {
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
// static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
......
......@@ -5,16 +5,22 @@
TEST(Place, Equality) {
paddle::platform::CPUPlace cpu;
paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
paddle::platform::CUDNNPlace d0(0), d1(1), dd0(0);
EXPECT_EQ(cpu, cpu);
EXPECT_EQ(g0, g0);
EXPECT_EQ(g1, g1);
EXPECT_EQ(g0, gg0);
EXPECT_EQ(d0, dd0);
EXPECT_NE(g0, g1);
EXPECT_NE(d0, d1);
EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0));
EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu));
EXPECT_TRUE(paddle::platform::is_gpu_place(d0));
EXPECT_FALSE(paddle::platform::places_are_same_class(g0, d0));
}
TEST(Place, Default) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册