From 7c2fd61869f0a45fe0a1a90b421f88475fbd1bcf Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 15 Nov 2017 15:40:30 +0800 Subject: [PATCH] fix data layout --- paddle/operators/pool_cudnn_op.cu | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu index e4389242339..a239fe27d4b 100644 --- a/paddle/operators/pool_cudnn_op.cu +++ b/paddle/operators/pool_cudnn_op.cu @@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel { ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; ScopedPoolingDescriptor pool_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); @@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; ScopedPoolingDescriptor pool_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); -- GitLab