diff --git a/paddle/operators/conv_transpose_cudnn_op.cu b/paddle/operators/conv_transpose_cudnn_op.cu index cd31896f2c06e6b7e11841eb5e8533d717e76b16..00e0ec255dcb11a71bc83f5974627d9364f02de7 100644 --- a/paddle/operators/conv_transpose_cudnn_op.cu +++ b/paddle/operators/conv_transpose_cudnn_op.cu @@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { ScopedTensorDescriptor output_desc; ScopedFilterDescriptor filter_desc; ScopedConvolutionDescriptor conv_desc; - DataLayout layout = DataLayout::kNCHW; + DataLayout layout; + + if (strides.size() == 2U) { + layout = DataLayout::kNCHW; + } else { + layout = DataLayout::kNCDHW; + } - // N, M, H, W + // (N, M, H, W) or (N, M, D, H, W) cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); - // N, C, O_h, O_w + // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( layout, framework::vectorize2int(output->dims())); - // M, C, K_h, K_w + // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w) cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( layout, framework::vectorize2int(filter->dims())); cudnnConvolutionDescriptor_t cudnn_conv_desc = @@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { ScopedConvolutionDescriptor conv_desc; DataLayout layout = DataLayout::kNCHW; - // Input: (N, M, H, W) + // Input: (N, M, H, W) or (N, M, D, H, W) cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( layout, framework::vectorize2int(input->dims())); - // Output: (N, C, O_H, O_W) + // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( layout, framework::vectorize2int(output_grad->dims())); - // Filter (M, C, K_H, K_W) + // Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w) cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( layout, framework::vectorize2int(filter->dims())); diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h index ce3421a3cb840e4c1e872eea12dedc1150c85962..8d75fceae8bfba42471202a8f47438c140a2fa1c 100644 --- a/paddle/platform/cudnn_helper.h +++ b/paddle/platform/cudnn_helper.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -63,9 +60,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { } \ } while (false) -enum class DataLayout { +enum class DataLayout { // Not use kNHWC, kNCHW, + kNCDHW, kNCHW_VECT_C, }; @@ -107,12 +105,15 @@ class CudnnDataType { } }; -inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) { +inline cudnnTensorFormat_t GetCudnnTensorFormat( + const DataLayout& order) { // Not use switch (order) { case DataLayout::kNHWC: return CUDNN_TENSOR_NHWC; case DataLayout::kNCHW: return CUDNN_TENSOR_NCHW; + case DataLayout::kNCDHW: + return CUDNN_TENSOR_NCHW; // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW default: PADDLE_THROW("Unknown cudnn equivalent for order"); } @@ -139,7 +140,7 @@ class ScopedTensorDescriptor { strides[i] = dims[i + 1] * strides[i + 1]; } // Update tensor descriptor dims setting if groups > 1 - // FIXME(typhoonzero): Assume using NCHW order + // FIXME(typhoonzero): Assume using NCHW or NCDHW order std::vector dims_with_group(dims.begin(), dims.end()); // copy if (groups > 1) { dims_with_group[1] = dims_with_group[1] / groups;