diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu index e4389242339ff03762a91bdc5359827d4bf0eafc..a239fe27d4b58d2ace07434901ebb611adea9e50 100644 --- a/paddle/operators/pool_cudnn_op.cu +++ b/paddle/operators/pool_cudnn_op.cu @@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel { ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; ScopedPoolingDescriptor pool_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); @@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; ScopedPoolingDescriptor pool_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims()));