Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
0f83674c
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f83674c
编写于
11月 17, 2017
作者:
C
chengduo
提交者:
GitHub
11月 17, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5603 from chengduoZH/Add_conv3d_transpose_cudnn_op
add conv3d_trans_cudnn_op
上级
2113cbfd
c359e39b
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
122 addition
and
54 deletion
+122
-54
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+20
-13
paddle/operators/conv_cudnn_op.cu.cc
paddle/operators/conv_cudnn_op.cu.cc
+4
-6
paddle/operators/conv_op.cc
paddle/operators/conv_op.cc
+8
-4
paddle/operators/conv_op.cu.cc
paddle/operators/conv_op.cu.cc
+8
-4
paddle/operators/conv_transpose_cudnn_op.cc
paddle/operators/conv_transpose_cudnn_op.cc
+29
-1
paddle/operators/conv_transpose_cudnn_op.cu.cc
paddle/operators/conv_transpose_cudnn_op.cu.cc
+20
-11
paddle/operators/conv_transpose_op.cc
paddle/operators/conv_transpose_op.cc
+8
-4
paddle/operators/conv_transpose_op.cu.cc
paddle/operators/conv_transpose_op.cu.cc
+8
-4
paddle/operators/pool_cudnn_op.cu.cc
paddle/operators/pool_cudnn_op.cu.cc
+1
-2
paddle/platform/cudnn_helper.h
paddle/platform/cudnn_helper.h
+10
-5
python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
+6
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
0f83674c
...
...
@@ -61,6 +61,18 @@ function(op_library TARGET)
set
(
pybind_flag 1
)
endif
()
if
(
"
${
TARGET
}
"
STREQUAL
"compare_op"
)
set
(
pybind_flag 1
)
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(equal);
\n
"
)
endif
()
# conv_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"conv_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(conv2d);
\n
"
)
endif
()
# pool_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"pool_op"
)
set
(
pybind_flag 1
)
...
...
@@ -68,9 +80,11 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(pool2d);
\n
"
)
endif
()
if
(
"
${
TARGET
}
"
STREQUAL
"compare_op"
)
# pool_cudnn_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"pool_cudnn_op"
)
set
(
pybind_flag 1
)
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(equal);
\n
"
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(pool2d_cudnn);
\n
"
)
endif
()
# pool_with_index_op contains several operators
...
...
@@ -80,25 +94,18 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(max_pool2d_with_index);
\n
"
)
endif
()
# conv_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"conv_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(conv2d);
\n
"
)
endif
()
# conv_transpose_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"conv_transpose_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(conv2d_transpose);
\n
"
)
endif
()
#
pool_cudnn_op contains several
operators
if
(
"
${
TARGET
}
"
STREQUAL
"
pool
_cudnn_op"
)
#
conv_transpose_cudnn_op contains two
operators
if
(
"
${
TARGET
}
"
STREQUAL
"
conv_transpose
_cudnn_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(
pool2d
_cudnn);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(
conv2d_transpose
_cudnn);
\n
"
)
endif
()
# save_restore_op contains several operators
...
...
paddle/operators/conv_cudnn_op.cu.cc
浏览文件 @
0f83674c
...
...
@@ -226,9 +226,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
if
(
input_grad
)
{
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
input_grad
);
t
.
device
(
ctx
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
// Because beta is zero, it is unnecessary to reset input_grad.
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
handle
,
&
alpha
,
cudnn_filter_desc
,
...
...
@@ -241,9 +240,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if
(
filter_grad
)
{
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
filter_grad
);
t
.
device
(
ctx
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
// Because beta is zero, it is unnecessary to reset filter_grad.
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
...
...
paddle/operators/conv_op.cc
浏览文件 @
0f83674c
...
...
@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops
::
ConvOpGrad
);
REGISTER_OP_CPU_KERNEL
(
conv2d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
conv2d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
conv2d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
conv3d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
conv3d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
conv3d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/conv_op.cu.cc
浏览文件 @
0f83674c
...
...
@@ -17,11 +17,15 @@
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
conv2d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
GemmConvKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
conv2d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
conv2d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
conv3d
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
GemmConvKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
conv3d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
conv3d_grad
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvGradKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
paddle/operators/conv
2d
_transpose_cudnn_op.cc
→
paddle/operators/conv_transpose_cudnn_op.cc
浏览文件 @
0f83674c
...
...
@@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
framework
::
OpAttrChecker
*
op_checker
)
:
Conv2DTransposeOpMaker
(
proto
,
op_checker
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"dilations of convolution operator."
)
.
SetDefault
(
std
::
vector
<
int
>
{
1
,
1
});
.
SetDefault
({
1
,
1
});
AddAttr
<
int
>
(
"workspace_size_MB"
,
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted."
)
.
SetDefault
(
4096
);
}
};
class
CudnnConv3DTransposeOpMaker
:
public
Conv3DTransposeOpMaker
{
public:
CudnnConv3DTransposeOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
Conv3DTransposeOpMaker
(
proto
,
op_checker
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"dilations"
,
"dilations of convolution operator."
)
.
SetDefault
({
1
,
1
,
1
});
AddAttr
<
int
>
(
"workspace_size_MB"
,
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
...
...
@@ -48,3 +65,14 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL
(
conv2d_transpose_cudnn_grad
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP
(
conv3d_transpose_cudnn
,
ops
::
ConvTransposeOp
,
ops
::
CudnnConv3DTransposeOpMaker
,
conv3d_transpose_cudnn_grad
,
ops
::
ConvTransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
conv3d_transpose_cudnn
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
conv3d_transpose_cudnn_grad
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/conv
2d
_transpose_cudnn_op.cu.cc
→
paddle/operators/conv_transpose_cudnn_op.cu.cc
浏览文件 @
0f83674c
...
...
@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
ScopedTensorDescriptor
output_desc
;
ScopedFilterDescriptor
filter_desc
;
ScopedConvolutionDescriptor
conv_desc
;
DataLayout
layout
=
DataLayout
::
kNCHW
;
DataLayout
layout
;
if
(
strides
.
size
()
==
2U
)
{
layout
=
DataLayout
::
kNCHW
;
}
else
{
layout
=
DataLayout
::
kNCDHW
;
}
//
N, M, H, W
//
(N, M, H, W) or (N, M, D, H, W)
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
input
->
dims
()));
//
N, C, O_h, O_w
//
(N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
cudnnTensorDescriptor_t
cudnn_output_desc
=
output_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
output
->
dims
()));
//
M, C, K_h, K_w
//
(M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
cudnnFilterDescriptor_t
cudnn_filter_desc
=
filter_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
filter
->
dims
()));
cudnnConvolutionDescriptor_t
cudnn_conv_desc
=
...
...
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor
conv_desc
;
DataLayout
layout
=
DataLayout
::
kNCHW
;
// Input: (N, M, H, W)
// Input: (N, M, H, W)
or (N, M, D, H, W)
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
input
->
dims
()));
// Output: (N, C, O_
H, O_W
)
// Output: (N, C, O_
h, O_w) or (N, C, O_d, O_h, O_w
)
cudnnTensorDescriptor_t
cudnn_output_desc
=
output_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
output_grad
->
dims
()));
// Filter (M, C, K_
H, K_W
)
// Filter (M, C, K_
h, K_w) or (M, C, K_d K_h, K_w
)
cudnnFilterDescriptor_t
cudnn_filter_desc
=
filter_desc
.
descriptor
<
T
>
(
layout
,
framework
::
vectorize2int
(
filter
->
dims
()));
...
...
@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
if
(
input_grad
)
{
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
set_constant
(
ctx
.
device_context
(),
input_grad
,
0
);
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
handle
,
&
alpha
,
cudnn_output_desc
,
output_grad_data
,
cudnn_filter_desc
,
filter_data
,
cudnn_conv_desc
,
data_algo
,
...
...
@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if
(
filter_grad
)
{
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
set_constant
(
ctx
.
device_context
(),
filter_grad
,
0
);
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
handle
,
&
alpha
,
cudnn_output_desc
,
output_grad_data
,
cudnn_input_desc
,
...
...
@@ -234,3 +238,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops
::
CudnnConvTransposeOpKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
conv2d_transpose_cudnn_grad
,
ops
::
CudnnConvTransposeGradOpKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
conv3d_transpose_cudnn
,
ops
::
CudnnConvTransposeOpKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
conv3d_transpose_cudnn_grad
,
ops
::
CudnnConvTransposeGradOpKernel
<
float
>
);
paddle/operators/conv_transpose_op.cc
浏览文件 @
0f83674c
...
...
@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL
(
conv2d_transpose
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
conv2d_transpose_grad
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP
(
conv3d_transpose
,
ops
::
ConvTransposeOp
,
ops
::
Conv3DTransposeOpMaker
,
conv3d_transpose_grad
,
ops
::
ConvTransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
conv3d_transpose
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
conv3d_transpose_grad
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
CPUPlace
,
double
>
);
paddle/operators/conv_transpose_op.cu.cc
浏览文件 @
0f83674c
...
...
@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL
(
conv2d_transpose
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
conv2d_transpose_grad
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
conv3d_transpose
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvTransposeKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
REGISTER_OP_GPU_KERNEL
(
conv3d_transpose_grad
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
,
ops
::
GemmConvTransposeGradKernel
<
paddle
::
platform
::
GPUPlace
,
double
>
);
paddle/operators/pool_cudnn_op.cu.cc
浏览文件 @
0f83674c
...
...
@@ -135,8 +135,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
if
(
input_grad
)
{
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
SetConstant
<
paddle
::
platform
::
GPUPlace
,
T
>
set_zero
;
set_zero
(
ctx
.
device_context
(),
input_grad
,
static_cast
<
T
>
(
0
));
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnPoolingBackward
(
handle
,
cudnn_pool_desc
,
&
alpha
,
cudnn_output_desc
,
output_data
,
...
...
paddle/platform/cudnn_helper.h
浏览文件 @
0f83674c
...
...
@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
} \
} while (false)
enum
class
DataLayout
{
enum
class
DataLayout
{
// Not use
kNHWC
,
kNCHW
,
kNCDHW
,
kNCHW_VECT_C
,
};
...
...
@@ -107,12 +108,15 @@ class CudnnDataType<double> {
}
};
inline
cudnnTensorFormat_t
GetCudnnTensorFormat
(
const
DataLayout
&
order
)
{
inline
cudnnTensorFormat_t
GetCudnnTensorFormat
(
const
DataLayout
&
order
)
{
// Not use
switch
(
order
)
{
case
DataLayout
::
kNHWC
:
return
CUDNN_TENSOR_NHWC
;
case
DataLayout
::
kNCHW
:
return
CUDNN_TENSOR_NCHW
;
case
DataLayout
::
kNCDHW
:
return
CUDNN_TENSOR_NCHW
;
// TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
default:
PADDLE_THROW
(
"Unknown cudnn equivalent for order"
);
}
...
...
@@ -139,7 +143,7 @@ class ScopedTensorDescriptor {
strides
[
i
]
=
dims
[
i
+
1
]
*
strides
[
i
+
1
];
}
// Update tensor descriptor dims setting if groups > 1
// FIXME(typhoonzero): Assume using NCHW order
// FIXME(typhoonzero): Assume using NCHW or
NCDHW or
der
std
::
vector
<
int
>
dims_with_group
(
dims
.
begin
(),
dims
.
end
());
// copy
if
(
groups
>
1
)
{
dims_with_group
[
1
]
=
dims_with_group
[
1
]
/
groups
;
...
...
@@ -176,9 +180,10 @@ class ScopedFilterDescriptor {
const
cudnnDataType_t
type
,
const
std
::
vector
<
int
>&
kernel
,
const
int
groups
=
1
)
{
// filter layout: MCHW, where M is the number of
// filter layout: MCHW
(MCDHW)
, where M is the number of
// output image channels, C is the number of input image channels,
// H and W is height and width of filter.
// D is the depth of the filter, H is the height of the filter, and W is the
// width of the filter.
std
::
vector
<
int
>
kernel_with_group
(
kernel
.
begin
(),
kernel
.
end
());
if
(
groups
>
1
)
{
// M /= groups
...
...
python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
浏览文件 @
0f83674c
...
...
@@ -108,5 +108,11 @@ class TestWithStride(TestConv3dTransposeOp):
self
.
filter_size
=
[
f_c
,
6
,
3
,
3
,
3
]
# ------------ test_cudnn ------------
class
TestCudnn
(
TestConv3dTransposeOp
):
def
init_op_type
(
self
):
self
.
op_type
=
"conv3d_transpose_cudnn"
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录