From c18f1bd7168c663a80803f336d4c6c7bcb5d7239 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Thu, 31 Oct 2019 10:38:27 +0800
Subject: [PATCH] fix the bug of conv_transpose: compatible with AnyLayout
 setting, test=develop (#20897)

---
 .../operators/conv_transpose_cudnn_op.cu      |  2 +-
 paddle/fluid/operators/conv_transpose_op.h    |  4 +-
 paddle/fluid/operators/math/depthwise_conv.cu | 58 +++++++++----------
 paddle/fluid/operators/math/im2col.cc         |  2 +-
 paddle/fluid/operators/math/im2col.cu         | 26 ++++-----
 5 files changed, 47 insertions(+), 45 deletions(-)

diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu
index 82eb571240..15e8f38312 100644
--- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu
+++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu
@@ -316,7 +316,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
     int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
     const std::string data_layout_str = ctx.Attr<std::string>("data_format");
     const paddle::operators::DataLayout data_layout =
-        (data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
+        (data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
 
     // if channel_last, transpose to channel_first
     Tensor input_transpose;
diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h
index 8bcef8e49a..fa3bb84b06 100644
--- a/paddle/fluid/operators/conv_transpose_op.h
+++ b/paddle/fluid/operators/conv_transpose_op.h
@@ -328,7 +328,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
         col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                 data_layout);
       }
-      output_batch_vec.push_back(out_slice);
+      if (data_layout == framework::DataLayout::kNHWC) {
+        output_batch_vec.push_back(out_slice);
+      }
     }
     if (data_layout == framework::DataLayout::kNHWC) {
       concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
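These two hunks carry the behavioral fix. Paddle's DataLayout enum has more
values than NCHW and NHWC (the commit subject's AnyLayout among them), so
testing `data_layout == kNCHW` sends kAnyLayout tensors down the channel-last
branch; testing `data_layout != kNHWC` instead makes channel-first the
fallback for anything that is not explicitly NHWC. The second hunk also
collects the per-batch output slices only in the NHWC case, matching the
NHWC-only concat_functor call right below it. A minimal standalone sketch of
the flipped predicate (illustrative only, not Paddle code; the enum is
reduced to the three values that matter here):

    // predicate_sketch.cc -- hypothetical file, for illustration.
    #include <cassert>

    enum class DataLayout { kNHWC, kNCHW, kAnyLayout };

    // The fixed test: anything that is not explicitly channel-last is
    // treated as channel-first.
    bool TreatAsNCHW(DataLayout layout) { return layout != DataLayout::kNHWC; }

    int main() {
      assert(TreatAsNCHW(DataLayout::kNCHW));
      assert(!TreatAsNCHW(DataLayout::kNHWC));
      // Under the old test (layout == kNCHW) the next assertion would fail,
      // i.e. kAnyLayout tensors were routed through NHWC indexing.
      assert(TreatAsNCHW(DataLayout::kAnyLayout));
      return 0;
    }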
diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu
index 28083a1b20..2c686c8ba7 100644
--- a/paddle/fluid/operators/math/depthwise_conv.cu
+++ b/paddle/fluid/operators/math/depthwise_conv.cu
@@ -60,7 +60,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
       const int w_in_end = w_in_start + filter_width * dilate_width;
 
       int in_offset;
-      if (data_layout == DataLayout::kNCHW) {
+      if (data_layout != DataLayout::kNHWC) {
         in_offset =
             ((batch * input_channels + c_in) * input_height) * input_width;
       } else {
@@ -78,7 +78,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
           if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
               w_in < w_end) {
             int offset;
-            if (data_layout == DataLayout::kNCHW) {
+            if (data_layout != DataLayout::kNHWC) {
               offset = in_offset + h_in * input_width + w_in;
             } else {
               offset = in_offset +
@@ -94,7 +94,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
       }
     }
     int index;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                   output_width +
               w_out;
@@ -131,7 +131,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
       const int w_in_end = w_in_start + c_filter * dilate_width;
 
       int in_offset;
-      if (data_layout == DataLayout::kNCHW) {
+      if (data_layout != DataLayout::kNHWC) {
         in_offset =
             ((batch * input_channels + c_in) * input_height) * input_width;
       } else {
@@ -150,7 +150,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
           if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
             int offset;
-            if (data_layout == DataLayout::kNCHW) {
+            if (data_layout != DataLayout::kNHWC) {
               offset = in_offset + h_in * input_width + w_in;
             } else {
               offset = in_offset +
@@ -166,7 +166,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
       }
     }
     int index;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       index = ((batch * gridDim.x + c_out) * output_height + h_out) *
                   output_width +
               w_out;
@@ -252,7 +252,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
       T value = 0;
 
       int index;
-      if (data_layout == DataLayout::kNCHW) {
+      if (data_layout != DataLayout::kNHWC) {
        index = ((batch * gridDim.x + c_in) * input_height + h_in) *
                    input_width +
                w_in;
@@ -283,7 +283,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
               s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
               s_w_out < output_width) {
             int output_grad_offset;
-            if (data_layout == DataLayout::kNCHW) {
+            if (data_layout != DataLayout::kNHWC) {
              output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
@@ -335,7 +335,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
       T value = 0;
 
       int index;
-      if (data_layout == DataLayout::kNCHW) {
+      if (data_layout != DataLayout::kNHWC) {
        index = ((batch * gridDim.x + c_in) * input_height + h_in) *
                    input_width +
                w_in;
@@ -363,7 +363,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
               s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
               s_w_out < output_width) {
             int output_grad_offset;
-            if (data_layout == DataLayout::kNCHW) {
+            if (data_layout != DataLayout::kNHWC) {
               output_grad_offset =
                   ((batch * output_channels + c_out) * output_height +
                    s_h_out) *
@@ -449,7 +449,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
 #define gaid_nhwc(N, H, W, C) \
   ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
     int input_id;
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       input_id = ((bid * (gridDim.z / filter_multiplier) +
                    kernel_id / filter_multiplier) *
                       input_height +
@@ -528,19 +528,19 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
     const int batch_size = input.dims()[0];
     const int input_channels =
-        (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
     const int input_height =
-        (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
     const int input_width =
-        (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
     const int output_channels =
-        (data_layout == DataLayout::kNCHW ? output->dims()[1]
+        (data_layout != DataLayout::kNHWC ? output->dims()[1]
                                           : output->dims()[3]);
     const int output_height =
-        (data_layout == DataLayout::kNCHW ? output->dims()[2]
+        (data_layout != DataLayout::kNHWC ? output->dims()[2]
                                           : output->dims()[1]);
     const int output_width =
-        (data_layout == DataLayout::kNCHW ? output->dims()[3]
+        (data_layout != DataLayout::kNHWC ? output->dims()[3]
                                           : output->dims()[2]);
     const int ksize_height = filter.dims()[2];
     const int ksize_width = filter.dims()[3];
     const int stride_height = strides[0];
@@ -614,19 +614,19 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
     const int batch_size = input.dims()[0];
     const int input_channels =
-        (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
     const int input_height =
-        (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
     const int input_width =
-        (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
     const int output_channels =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                           : output_grad.dims()[3]);
     const int output_height =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                           : output_grad.dims()[1]);
     const int output_width =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                           : output_grad.dims()[2]);
     const int ksize_height = filter.dims()[2];
     const int ksize_width = filter.dims()[3];
     const int stride_height = strides[0];
@@ -700,19 +700,19 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
     const int batch_size = input.dims()[0];
     const int input_channels =
-        (data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
     const int input_height =
-        (data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
     const int input_width =
-        (data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
     const int output_channels =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
                                           : output_grad.dims()[3]);
     const int output_height =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
                                           : output_grad.dims()[1]);
     const int output_width =
-        (data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
+        (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
                                           : output_grad.dims()[2]);
     const int ksize_height = filter_grad->dims()[2];
     const int ksize_width = filter_grad->dims()[3];
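Every hunk in depthwise_conv.cu flips the same predicate in front of a
layout-dependent flat offset, so kAnyLayout now reuses the channel-first
index math instead of falling into the NHWC branch. For reference, a
standalone sketch (hypothetical FlatOffset helper, not the kernel code) of
the two offset formulas the kernels choose between:

    // offset_sketch.cc -- illustrative only.
    #include <cstdio>

    enum class DataLayout { kNHWC, kNCHW, kAnyLayout };

    // Flat offset of element (n, c, h, w) in a contiguous 4-D tensor.
    int FlatOffset(DataLayout data_layout, int n, int c, int h, int w,
                   int channels, int height, int width) {
      if (data_layout != DataLayout::kNHWC) {
        // NCHW -- and, after this patch, the kAnyLayout fallback.
        return ((n * channels + c) * height + h) * width + w;
      }
      // NHWC: channel is the fastest-varying axis.
      return ((n * height + h) * width + w) * channels + c;
    }

    int main() {
      // kAnyLayout lands on the same element as kNCHW (both print 83)...
      printf("%d %d\n", FlatOffset(DataLayout::kNCHW, 0, 1, 2, 3, 4, 8, 8),
             FlatOffset(DataLayout::kAnyLayout, 0, 1, 2, 3, 4, 8, 8));
      // ...while an explicit kNHWC request still gets channel-last (77).
      printf("%d\n", FlatOffset(DataLayout::kNHWC, 0, 1, 2, 3, 4, 8, 8));
      return 0;
    }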
diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc
index 835998fdce..094a723782 100644
--- a/paddle/fluid/operators/math/im2col.cc
+++ b/paddle/fluid/operators/math/im2col.cc
@@ -115,7 +115,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
             if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
                 (im_col_idx) >= 0 && (im_col_idx) < im_width) {
               int im_offset;
-              if (data_layout == DataLayout::kNCHW) {
+              if (data_layout != DataLayout::kNHWC) {
                 im_offset =
                     (c_im * im_height + im_row_idx) * im_width + im_col_idx;
               } else {
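The CPU col2im path gets the same fix at the point where a column value is
accumulated back into the image tensor. A standalone sketch of that
accumulation (hypothetical ScatterAdd helper; the batch index is omitted, as
the functor works on one image at a time):

    // col2im_sketch.cc -- illustrative only.
    #include <cstdio>
    #include <vector>

    enum class DataLayout { kNHWC, kNCHW, kAnyLayout };

    // Accumulate one column-buffer value back into image element
    // (c_im, im_row_idx, im_col_idx), as the Col2ImFunctor loop does.
    void ScatterAdd(std::vector<float>* im, DataLayout data_layout, float v,
                    int c_im, int im_row_idx, int im_col_idx, int channels,
                    int height, int width) {
      if (im_row_idx < 0 || im_row_idx >= height || im_col_idx < 0 ||
          im_col_idx >= width) {
        return;  // out-of-range positions correspond to padding
      }
      int im_offset;
      if (data_layout != DataLayout::kNHWC) {
        im_offset = (c_im * height + im_row_idx) * width + im_col_idx;
      } else {
        im_offset = (im_row_idx * width + im_col_idx) * channels + c_im;
      }
      (*im)[im_offset] += v;
    }

    int main() {
      std::vector<float> im(3 * 4 * 4, 0.f);
      // With kAnyLayout the value must land at the NCHW offset.
      ScatterAdd(&im, DataLayout::kAnyLayout, 1.f, 1, 2, 3, 3, 4, 4);
      printf("%f\n", im[(1 * 4 + 2) * 4 + 3]);  // prints 1.000000
      return 0;
    }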
diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu
index ffb598dced..97719300da 100644
--- a/paddle/fluid/operators/math/im2col.cu
+++ b/paddle/fluid/operators/math/im2col.cu
@@ -33,14 +33,14 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
   const int index =
       (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   if (index < num_outs) {
-    int w_out = (data_layout == DataLayout::kNCHW
+    int w_out = (data_layout != DataLayout::kNHWC
                      ? index % col_width
                      : (index / input_channels) % col_width);
-    int h_out = (data_layout == DataLayout::kNCHW
+    int h_out = (data_layout != DataLayout::kNHWC
                      ? (index / col_width) % col_height
                      : (index / input_channels / col_width) % col_height);
     int channel_in =
-        (data_layout == DataLayout::kNCHW ? index / col_width / col_height
+        (data_layout != DataLayout::kNHWC ? index / col_width / col_height
                                           : index % input_channels);
     int channel_out = channel_in * filter_height * filter_width;
     int h_in = h_out * stride_height - padding_height;
@@ -52,7 +52,7 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
         int rIdx = h_in + i * dilation_h;
         int cIdx = w_in + j * dilation_w;
         int im_idx;
-        if (data_layout == DataLayout::kNCHW) {
+        if (data_layout != DataLayout::kNHWC) {
          im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
         } else {
           im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
@@ -86,11 +86,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
     int im_channels =
-        (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
+        (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
     int im_height =
-        (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
+        (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
     int im_width =
-        (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
+        (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
     int filter_height = col->dims()[1];
     int filter_width = col->dims()[2];
     int col_height = col->dims()[3];
@@ -127,14 +127,14 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
   if (index < n) {
     T val = 0;
-    int w = (data_layout == DataLayout::kNCHW
+    int w = (data_layout != DataLayout::kNHWC
                  ? index % im_width + padding_width
                  : (index / input_channels) % im_width + padding_width);
-    int h = (data_layout == DataLayout::kNCHW
+    int h = (data_layout != DataLayout::kNHWC
                  ? (index / im_width) % im_height + padding_height
                  : (index / input_channels / im_width) % im_height +
                        padding_height);
-    int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height
+    int c = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
                                               : index % input_channels);
 
     // compute the start and end of the output
@@ -187,11 +187,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
     int im_channels =
-        (data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
+        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
     int im_height =
-        (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
+        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
     int im_width =
-        (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
+        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
     int filter_height = col.dims()[1];
     int filter_width = col.dims()[2];
     int col_height = col.dims()[3];
--
GitLab
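Note for reviewers (illustrative, not part of the patch): the GPU im2col
kernel recovers (channel, row, column) from a linear thread index with the
layout-dependent divisions shown in the first im2col.cu hunk. A standalone
sketch of that decomposition:

    // index_decomposition_sketch.cc -- illustrative only.
    #include <cstdio>

    enum class DataLayout { kNHWC, kNCHW, kAnyLayout };

    // Mirror of the w_out / h_out / channel_in lines in the im2col kernel.
    void Decompose(DataLayout data_layout, int index, int input_channels,
                   int col_height, int col_width) {
      int w_out = (data_layout != DataLayout::kNHWC
                       ? index % col_width
                       : (index / input_channels) % col_width);
      int h_out = (data_layout != DataLayout::kNHWC
                       ? (index / col_width) % col_height
                       : (index / input_channels / col_width) % col_height);
      int channel_in =
          (data_layout != DataLayout::kNHWC ? index / col_width / col_height
                                            : index % input_channels);
      printf("c=%d h=%d w=%d\n", channel_in, h_out, w_out);
    }

    int main() {
      Decompose(DataLayout::kNCHW, 38, 3, 5, 5);       // c=1 h=2 w=3
      Decompose(DataLayout::kAnyLayout, 38, 3, 5, 5);  // same as kNCHW now
      Decompose(DataLayout::kNHWC, 38, 3, 5, 5);       // c=2 h=2 w=2
      return 0;
    }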