提交 a93a59ec 编写于 作者: C chengduoZH

add cudnn 3d unit test

上级 93c6e52a
...@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { ...@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \ } \
} while (false) } while (false)
enum class DataLayout { enum class DataLayout { // Not use
kNHWC, kNHWC,
kNCHW, kNCHW,
kNCDHW,
kNCHW_VECT_C, kNCHW_VECT_C,
}; };
...@@ -107,12 +108,15 @@ class CudnnDataType<double> { ...@@ -107,12 +108,15 @@ class CudnnDataType<double> {
} }
}; };
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) { inline cudnnTensorFormat_t GetCudnnTensorFormat(
const DataLayout& order) { // Not use
switch (order) { switch (order) {
case DataLayout::kNHWC: case DataLayout::kNHWC:
return CUDNN_TENSOR_NHWC; return CUDNN_TENSOR_NHWC;
case DataLayout::kNCHW: case DataLayout::kNCHW:
return CUDNN_TENSOR_NCHW; return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
default: default:
PADDLE_THROW("Unknown cudnn equivalent for order"); PADDLE_THROW("Unknown cudnn equivalent for order");
} }
......
...@@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) { ...@@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) {
EXPECT_EQ(strides[2], 6); EXPECT_EQ(strides[2], 6);
EXPECT_EQ(strides[1], 36); EXPECT_EQ(strides[1], 36);
EXPECT_EQ(strides[0], 144); EXPECT_EQ(strides[0], 144);
// test tensor5d: ScopedTensorDescriptor
ScopedTensorDescriptor tensor5d_desc;
std::vector<int> shape_5d = {2, 4, 6, 6, 6};
auto desc_5d = tensor5d_desc.descriptor<float>(DataLayout::kNCDHW, shape_5d);
std::vector<int> dims_5d(5);
std::vector<int> strides_5d(5);
paddle::platform::dynload::cudnnGetTensorNdDescriptor(
desc_5d, 5, &type, &nd, dims_5d.data(), strides_5d.data());
EXPECT_EQ(nd, 5);
for (size_t i = 0; i < dims_5d.size(); ++i) {
EXPECT_EQ(dims_5d[i], shape_5d[i]);
}
EXPECT_EQ(strides_5d[4], 1);
EXPECT_EQ(strides_5d[3], 6);
EXPECT_EQ(strides_5d[2], 36);
EXPECT_EQ(strides_5d[1], 216);
EXPECT_EQ(strides_5d[0], 864);
} }
TEST(CudnnHelper, ScopedFilterDescriptor) { TEST(CudnnHelper, ScopedFilterDescriptor) {
...@@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) { ...@@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) {
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
EXPECT_EQ(kernel[i], shape[i]); EXPECT_EQ(kernel[i], shape[i]);
} }
ScopedFilterDescriptor filter_desc_4d;
std::vector<int> shape_4d = {2, 3, 3, 3};
auto desc_4d = filter_desc.descriptor<float>(DataLayout::kNCDHW, shape_4d);
std::vector<int> kernel_4d(4);
paddle::platform::dynload::cudnnGetFilterNdDescriptor(
desc_4d, 4, &type, &format, &nd, kernel_4d.data());
EXPECT_EQ(GetCudnnTensorFormat(DataLayout::kNCHW), format);
EXPECT_EQ(nd, 4);
for (size_t i = 0; i < shape_4d.size(); ++i) {
EXPECT_EQ(kernel_4d[i], shape_4d[i]);
}
} }
TEST(CudnnHelper, ScopedConvolutionDescriptor) { TEST(CudnnHelper, ScopedConvolutionDescriptor) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册