Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a93a59ec
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a93a59ec
编写于
11月 13, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add cudnn 3d unit test
上级
93c6e52a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
40 addition
and
2 deletion
+40
-2
paddle/platform/cudnn_helper.h
paddle/platform/cudnn_helper.h
+6
-2
paddle/platform/cudnn_helper_test.cc
paddle/platform/cudnn_helper_test.cc
+34
-0
未找到文件。
paddle/platform/cudnn_helper.h
浏览文件 @
a93a59ec
...
@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
...
@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \
} \
} while (false)
} while (false)
enum
class
DataLayout
{
enum
class
DataLayout
{
// Not use
kNHWC
,
kNHWC
,
kNCHW
,
kNCHW
,
kNCDHW
,
kNCHW_VECT_C
,
kNCHW_VECT_C
,
};
};
...
@@ -107,12 +108,15 @@ class CudnnDataType<double> {
...
@@ -107,12 +108,15 @@ class CudnnDataType<double> {
}
}
};
};
inline
cudnnTensorFormat_t
GetCudnnTensorFormat
(
const
DataLayout
&
order
)
{
inline
cudnnTensorFormat_t
GetCudnnTensorFormat
(
const
DataLayout
&
order
)
{
// Not use
switch
(
order
)
{
switch
(
order
)
{
case
DataLayout
::
kNHWC
:
case
DataLayout
::
kNHWC
:
return
CUDNN_TENSOR_NHWC
;
return
CUDNN_TENSOR_NHWC
;
case
DataLayout
::
kNCHW
:
case
DataLayout
::
kNCHW
:
return
CUDNN_TENSOR_NCHW
;
return
CUDNN_TENSOR_NCHW
;
case
DataLayout
::
kNCDHW
:
return
CUDNN_TENSOR_NCHW
;
// TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
default:
default:
PADDLE_THROW
(
"Unknown cudnn equivalent for order"
);
PADDLE_THROW
(
"Unknown cudnn equivalent for order"
);
}
}
...
...
paddle/platform/cudnn_helper_test.cc
浏览文件 @
a93a59ec
...
@@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) {
...
@@ -38,6 +38,26 @@ TEST(CudnnHelper, ScopedTensorDescriptor) {
EXPECT_EQ
(
strides
[
2
],
6
);
EXPECT_EQ
(
strides
[
2
],
6
);
EXPECT_EQ
(
strides
[
1
],
36
);
EXPECT_EQ
(
strides
[
1
],
36
);
EXPECT_EQ
(
strides
[
0
],
144
);
EXPECT_EQ
(
strides
[
0
],
144
);
// test tensor5d: ScopedTensorDescriptor
ScopedTensorDescriptor
tensor5d_desc
;
std
::
vector
<
int
>
shape_5d
=
{
2
,
4
,
6
,
6
,
6
};
auto
desc_5d
=
tensor5d_desc
.
descriptor
<
float
>
(
DataLayout
::
kNCDHW
,
shape_5d
);
std
::
vector
<
int
>
dims_5d
(
5
);
std
::
vector
<
int
>
strides_5d
(
5
);
paddle
::
platform
::
dynload
::
cudnnGetTensorNdDescriptor
(
desc_5d
,
5
,
&
type
,
&
nd
,
dims_5d
.
data
(),
strides_5d
.
data
());
EXPECT_EQ
(
nd
,
5
);
for
(
size_t
i
=
0
;
i
<
dims_5d
.
size
();
++
i
)
{
EXPECT_EQ
(
dims_5d
[
i
],
shape_5d
[
i
]);
}
EXPECT_EQ
(
strides_5d
[
4
],
1
);
EXPECT_EQ
(
strides_5d
[
3
],
6
);
EXPECT_EQ
(
strides_5d
[
2
],
36
);
EXPECT_EQ
(
strides_5d
[
1
],
216
);
EXPECT_EQ
(
strides_5d
[
0
],
864
);
}
}
TEST
(
CudnnHelper
,
ScopedFilterDescriptor
)
{
TEST
(
CudnnHelper
,
ScopedFilterDescriptor
)
{
...
@@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) {
...
@@ -60,6 +80,20 @@ TEST(CudnnHelper, ScopedFilterDescriptor) {
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
EXPECT_EQ
(
kernel
[
i
],
shape
[
i
]);
EXPECT_EQ
(
kernel
[
i
],
shape
[
i
]);
}
}
ScopedFilterDescriptor
filter_desc_4d
;
std
::
vector
<
int
>
shape_4d
=
{
2
,
3
,
3
,
3
};
auto
desc_4d
=
filter_desc
.
descriptor
<
float
>
(
DataLayout
::
kNCDHW
,
shape_4d
);
std
::
vector
<
int
>
kernel_4d
(
4
);
paddle
::
platform
::
dynload
::
cudnnGetFilterNdDescriptor
(
desc_4d
,
4
,
&
type
,
&
format
,
&
nd
,
kernel_4d
.
data
());
EXPECT_EQ
(
GetCudnnTensorFormat
(
DataLayout
::
kNCHW
),
format
);
EXPECT_EQ
(
nd
,
4
);
for
(
size_t
i
=
0
;
i
<
shape_4d
.
size
();
++
i
)
{
EXPECT_EQ
(
kernel_4d
[
i
],
shape_4d
[
i
]);
}
}
}
TEST
(
CudnnHelper
,
ScopedConvolutionDescriptor
)
{
TEST
(
CudnnHelper
,
ScopedConvolutionDescriptor
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录