diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index 364e3ab8d26c3f35f41f319b3d31b63964b93abe..94d1f707b74c2eae17d02771ad7d548e8b908dd9 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -903,29 +903,19 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
                             "and input channel number is %d",
                             output->dims()[1], input->dims()[1]));
     }
-    // transform tensor
-    Tensor transformed_input(input->type());
-    Tensor transformed_output(output->type());
-
-    if (channel_last) {
-      ResizeToChannelFirst<DeviceContext, T>(context, input,
-                                             &transformed_input);
-      TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
-
-      ResizeToChannelFirst<DeviceContext, T>(context, output,
-                                             &transformed_output);
-
-    } else {
-      transformed_input = *input;
-      transformed_output = *output;
-    }
 
     // update padding and dilation
-    auto in_dims = transformed_input.dims();
+    auto in_dims = input->dims();
     auto filter_dims = filter.dims();
 
     framework::DDim in_data_dims;
-    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_format);
+    if (data_layout != framework::DataLayout::kNHWC) {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    }
     framework::DDim filter_data_dims =
         framework::slice_ddim(filter_dims, 2, filter_dims.size());
@@ -944,16 +934,12 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
     if (fuse_relu) {
       math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
-      depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings,
-                    dilations, &transformed_output);
+      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
+                    output, data_layout);
     } else {
       math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
-      depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings,
-                    dilations, &transformed_output);
-    }
-    if (channel_last) {
-      TransToChannelLast<DeviceContext, T>(context, &transformed_output,
-                                           output);
+      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
+                    output, data_layout);
     }
   }
 };
@@ -981,33 +967,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("padding_algorithm");
     const std::string data_format = context.Attr<std::string>("data_format");
 
-    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
-
-    // transform Tensor
-    Tensor transformed_input(input->type());
-    Tensor transformed_output_grad(output_grad->type());
-
-    if (channel_last) {
-      ResizeToChannelFirst<DeviceContext, T>(context, input,
-                                             &transformed_input);
-      TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
-
-      ResizeToChannelFirst<DeviceContext, T>(context, output_grad,
-                                             &transformed_output_grad);
-      TransToChannelFirst<DeviceContext, T>(context, output_grad,
-                                            &transformed_output_grad);
-
-    } else {
-      transformed_input = *input;
-      transformed_output_grad = *output_grad;
-    }
-
     // update padding and dilation
-    auto in_dims = transformed_input.dims();
+    auto in_dims = input->dims();
     auto filter_dims = filter.dims();
 
     framework::DDim in_data_dims;
-    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_format);
+    if (data_layout != framework::DataLayout::kNHWC) {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    }
     framework::DDim filter_data_dims =
         framework::slice_ddim(filter_dims, 2, filter_dims.size());
     std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
@@ -1025,33 +996,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
-      Tensor transformed_input_grad(input_grad->type());
-      if (channel_last) {
-        ResizeToChannelFirst<DeviceContext, T>(context, input_grad,
-                                               &transformed_input_grad);
-
-      } else {
-        transformed_input_grad = *input_grad;
-      }
-
-      set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
+      set_zero(dev_ctx, input_grad, static_cast<T>(0));
 
       if (fuse_relu) {
         math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
             depthwiseConvInputGrad;
-        depthwiseConvInputGrad(dev_ctx, transformed_input, filter,
-                               transformed_output_grad, strides, paddings,
-                               dilations, &transformed_input_grad);
+        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
+                               paddings, dilations, input_grad, data_layout);
       } else {
         math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
            depthwiseConvInputGrad;
-        depthwiseConvInputGrad(dev_ctx, transformed_input, filter,
-                               transformed_output_grad, strides, paddings,
-                               dilations, &transformed_input_grad);
-      }
-      if (channel_last) {
-        TransToChannelLast<DeviceContext, T>(context, &transformed_input_grad,
-                                             input_grad);
+        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
+                               paddings, dilations, input_grad, data_layout);
       }
     }
 
@@ -1061,15 +1017,13 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
       if (fuse_relu) {
         math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
             depthwiseConvFilterGrad;
-        depthwiseConvFilterGrad(dev_ctx, transformed_input,
-                                transformed_output_grad, strides, paddings,
-                                dilations, filter_grad);
+        depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
+                                paddings, dilations, filter_grad, data_layout);
       } else {
         math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
             depthwiseConvFilterGrad;
-        depthwiseConvFilterGrad(dev_ctx, transformed_input,
-                                transformed_output_grad, strides, paddings,
-                                dilations, filter_grad);
+        depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
+                                paddings, dilations, filter_grad, data_layout);
      }
    }
  }
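The conv_op.h change drops the NHWC-to-NCHW transpose round trip and instead threads a `data_layout` argument down to the CUDA functors; the only layout-dependent host logic left is how the spatial dims are sliced out of the input shape. A minimal sketch of that slicing rule (`SpatialDims` is a hypothetical helper standing in for `framework::slice_ddim`, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class DataLayout { kNCHW, kNHWC };

// NCHW -> spatial dims are shape[2:]; NHWC -> shape[1:rank-1].
std::vector<int64_t> SpatialDims(const std::vector<int64_t>& in_dims,
                                 DataLayout layout) {
  assert(in_dims.size() == 4);
  if (layout != DataLayout::kNHWC) {
    return {in_dims[2], in_dims[3]};  // {H, W} of N,C,H,W
  }
  return {in_dims[1], in_dims[2]};  // {H, W} of N,H,W,C
}
```

diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu
index d116b620dc1e13973560480a674fe27437dfbefc..5fd543b5c6c5ccabd6d2e741a8ff4a9d1da5902c 100644
--- a/paddle/fluid/operators/math/depthwise_conv.cu
+++ b/paddle/fluid/operators/math/depthwise_conv.cu
@@ -22,6 +22,7 @@ limitations under the License. */
 namespace cub = hipcub;
 #endif
 #include "paddle/fluid/operators/math/depthwise_conv.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
@@ -52,8 +53,7 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
   const int filter_multiplier, const int filter_height,                    \
   const int filter_width, const int stride_height, const int stride_width, \
   const int padding_height, const int padding_width,                       \
-  const int dilate_height, const int dilate_width, T *const output_data,   \
-  const DataLayout data_layout = DataLayout::kNCHW
+  const int dilate_height, const int dilate_width, T *const output_data
 
 // A Cuda kernel to compute the depthwise convolution forward pass
 // in NCHW format.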
@@ -123,7 +123,6 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
   const int batch = idx / output_width / output_height / output_channels;
 
   const int c_in = c_out / filter_multiplier;
-  const T* weight = filter_data + c_out * filter_height * filter_width;
   T value = 0;
   const int h_in_start = -padding_height + h_out * stride_height;
   const int w_in_start = -padding_width + w_out * stride_width;
@@ -142,13 +141,14 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
   for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
     if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
       int offset = ((batch * input_height + h_in) * input_width + w_in) *
-                       output_channels +
+                       input_channels +
                    c_in;
       T in_data = input_data[offset];
+      const T* weight = filter_data + weight_offset * output_channels + c_out;
       if (fuse_relu_before_conv) {
-        value += weight[weight_offset] * max(0.0f, in_data);
+        value += weight[0] * max(0.0f, in_data);
       } else {
-        value += weight[weight_offset] * in_data;
+        value += weight[0] * in_data;
       }
     }
     weight_offset++;
@@ -161,10 +161,10 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
 }
 
 template <typename T, int c_filter, bool fuse_relu_before_conv>
-__device__ __inline__ void KernelDepthwiseConvCFilter(
+__device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
     ARG_DEFINE_KernelDepthwiseConv) {
-  const int kWeghtSize = c_filter * c_filter;
-  T r_weight[kWeghtSize];
+  const int kWeightSize = c_filter * c_filter;
+  T r_weight[kWeightSize];
   const int batch = blockIdx.y;
   const int c_out = blockIdx.x;
   const T* weight = filter_data + c_out * c_filter * c_filter;
@@ -182,13 +182,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
       const int h_in_end = h_in_start + c_filter * dilate_height;
       const int w_in_end = w_in_start + c_filter * dilate_width;
 
-      int in_offset;
-      if (data_layout != DataLayout::kNHWC) {
-        in_offset =
-            ((batch * input_channels + c_in) * input_height) * input_width;
-      } else {
-        in_offset = batch * input_height * input_width * input_channels;
-      }
+      int in_offset =
+          ((batch * input_channels + c_in) * input_height) * input_width;
 
       const int h_end = h_in_end < input_height ? h_in_end : input_height;
       const int w_end = w_in_end < input_width ? w_in_end : input_width;
@@ -201,13 +196,63 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
              w_in += dilate_width, w_f++) {
           if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
               w_in < input_width) {
-            int offset;
-            if (data_layout != DataLayout::kNHWC) {
-              offset = in_offset + h_in * input_width + w_in;
+            int offset = in_offset + h_in * input_width + w_in;
+            if (fuse_relu_before_conv) {
+              value += r_weight[h_f * c_filter + w_f] *
+                       max(0.0f, input_data[offset]);
             } else {
-              offset = in_offset +
-                       (h_in * input_width + w_in) * input_channels + c_in;
+              value += r_weight[h_f * c_filter + w_f] * input_data[offset];
             }
+          }
+        }
+      }
+      int index =
+          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
+          w_out;
+      output_data[index] = value;
+    }
+  }
+}
+
+template <typename T, int c_filter, bool fuse_relu_before_conv>
+__device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
+    ARG_DEFINE_KernelDepthwiseConv) {
+  const int batch = blockIdx.z;
+  int h_out = blockIdx.x * dilate_height + blockIdx.y;
+  if (h_out >= output_height) {
+    return;
+  }
+  int in_offset = batch * input_height * input_width * input_channels;
+  int out_offset =
+      (batch * output_height + h_out) * output_width * output_channels;
+  const int h_in_start = -padding_height + h_out * stride_height;
+  const int wi_size = (output_width + dilate_width - 1) / dilate_width;
+  const int kWeightSize = c_filter * c_filter;
+  T r_weight[kWeightSize];
+
+  for (int c_out = threadIdx.x; c_out < output_channels; c_out += blockDim.x) {
+    for (int i = 0; i < c_filter * c_filter; i++) {
+      const T* weight = filter_data + i * output_channels + c_out;
+      r_weight[i] = weight[0];
+    }
+    const int c_in = c_out / filter_multiplier;
+    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
+      int i_dw = i / wi_size;
+      int i_wi = i - i_dw * wi_size;
+      int w_out = i_wi * dilate_width + i_dw;
+      if (w_out >= output_width) {
+        continue;
+      }
+      T value = 0;
+      const int w_in_start = -padding_width + w_out * stride_width;
+      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
+           h_in += dilate_height, h_f++) {
+        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
+             w_in += dilate_width, w_f++) {
+          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
+              w_in < input_width) {
+            int offset =
+                in_offset + (h_in * input_width + w_in) * input_channels + c_in;
             if (fuse_relu_before_conv) {
               value += r_weight[h_f * c_filter + w_f] *
                        max(0.0f, input_data[offset]);
@@ -217,23 +262,14 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
             } else {
               value += r_weight[h_f * c_filter + w_f] * input_data[offset];
             }
           }
         }
       }
-      int index;
-      if (data_layout != DataLayout::kNHWC) {
-        index = ((batch * gridDim.x + c_out) * output_height + h_out) *
-                    output_width +
-                w_out;
-      } else {
-        index = ((batch * output_height + h_out) * output_width + w_out) *
-                    gridDim.x +
-                c_out;
-      }
+      int index = out_offset + w_out * output_channels + c_out;
       output_data[index] = value;
     }
   }
 }
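The NHWC kernels address the input as `((n*H + h)*W + w)*C + c` and read the filter in tap-major, channel-minor ("HWC") order, so threads working on consecutive output channels issue coalesced loads. A standalone sketch of the two addressing rules (hypothetical CUDA-style helpers, for illustration only):

```cpp
#include <cstdint>

// Linear offset of element (n, h, w, c) in an NHWC tensor.
__host__ __device__ inline int64_t NhwcOffset(int n, int h, int w, int c,
                                              int H, int W, int C) {
  return ((static_cast<int64_t>(n) * H + h) * W + w) * C + c;
}

// The patch stores the depthwise filter transposed to HWC order: element
// (h_f, w_f, c_out) lives at tap_index * output_channels + c_out, where
// tap_index = h_f * filter_width + w_f. Adjacent threads handling adjacent
// output channels therefore read adjacent addresses.
__host__ __device__ inline int64_t HwcFilterOffset(int tap_index, int c_out,
                                                   int output_channels) {
  return static_cast<int64_t>(tap_index) * output_channels + c_out;
}
```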
 
 template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
-          bool fuse_relu_before_conv>
+          DataLayout data_layout, bool fuse_relu_before_conv>
 __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
   int final_filter_multiplier = filter_multiplier;
   int h_stride = stride_height;
@@ -244,28 +280,37 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
     w_stride = c_stride;
   }
   if (c_filter == -1) {
-    if (data_layout == DataLayout::kNCHW) {
+    if (data_layout != DataLayout::kNHWC) {
       KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
           input_data, filter_data, batch_size, output_channels, output_height,
           output_width, input_channels, input_height, input_width,
           final_filter_multiplier, filter_height, filter_width, h_stride,
           w_stride, padding_height, padding_width, dilate_height, dilate_width,
-          output_data, data_layout);
+          output_data);
     } else {
       KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
           input_data, filter_data, batch_size, output_channels, output_height,
           output_width, input_channels, input_height, input_width,
           final_filter_multiplier, filter_height, filter_width, h_stride,
           w_stride, padding_height, padding_width, dilate_height, dilate_width,
-          output_data, data_layout);
+          output_data);
     }
   } else {
-    KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
-        input_data, filter_data, batch_size, output_channels, output_height,
-        output_width, input_channels, input_height, input_width,
-        final_filter_multiplier, filter_height, filter_width, h_stride,
-        w_stride, padding_height, padding_width, dilate_height, dilate_width,
-        output_data, data_layout);
+    if (data_layout != DataLayout::kNHWC) {
+      KernelDepthwiseConvCFilterNCHW<T, c_filter, fuse_relu_before_conv>(
+          input_data, filter_data, batch_size, output_channels, output_height,
+          output_width, input_channels, input_height, input_width,
+          final_filter_multiplier, filter_height, filter_width, h_stride,
+          w_stride, padding_height, padding_width, dilate_height, dilate_width,
+          output_data);
+    } else {
+      KernelDepthwiseConvCFilterNHWC<T, c_filter, fuse_relu_before_conv>(
+          input_data, filter_data, batch_size, output_channels, output_height,
+          output_width, input_channels, input_height, input_width,
+          final_filter_multiplier, filter_height, filter_width, h_stride,
+          w_stride, padding_height, padding_width, dilate_height, dilate_width,
+          output_data);
+    }
   }
 }
 
@@ -280,40 +325,27 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
   const int filter_width, const int stride_height, const int stride_width, \
   const int padding_height, const int padding_width,                       \
   const int dilate_height, const int dilate_width,                         \
-  T *const input_grad_data,                                                \
-  const DataLayout data_layout = DataLayout::kNCHW
+  T *const input_grad_data
 
 template <typename T, bool fuse_relu_before_conv>
-__device__ __inline__ void KernelDepthwiseConvInputGrad(
+__device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
     ARG_DEFINE_KernelDepthwiseConvInputGrad) {
+  const int batch = blockIdx.y;
+  const int c_in = blockIdx.x;
   for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
     for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
-      const int batch = blockIdx.y;
-      const int c_in = blockIdx.x;
-
       const int c_out_start = c_in * filter_multiplier;
-
       int h_out_start =
           h_in - (filter_height - 1) * dilate_height + padding_height;
-
       int h_out_end = h_in + padding_height;
-
       int w_out_start =
           w_in - (filter_width - 1) * dilate_width + padding_width;
-
       int w_out_end = w_in + padding_width;
 
       T value = 0;
-      int index;
-      if (data_layout != DataLayout::kNHWC) {
-        index =
-            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
-            w_in;
-      } else {
-        index =
-            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
-            c_in;
-      }
+      int index =
+          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
+          w_in;
 
       if (fuse_relu_before_conv) {
         if (input_data[index] <= 0) {
@@ -335,20 +367,67 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
           if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
               s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
               s_w_out < output_width) {
-            int output_grad_offset;
-            if (data_layout != DataLayout::kNHWC) {
-              output_grad_offset =
-                  ((batch * output_channels + c_out) * output_height +
-                   s_h_out) *
-                      output_width +
-                  s_w_out;
-            } else {
-              output_grad_offset =
-                  ((batch * output_height + s_h_out) * output_width +
-                   s_w_out) *
-                      output_channels +
-                  c_out;
-            }
+            int output_grad_offset =
+                ((batch * output_channels + c_out) * output_height +
+                 s_h_out) *
+                    output_width +
+                s_w_out;
             value += output_grad_data[output_grad_offset] *
                      filter_data[filter_offset];
           }
         }
       }
     }
     input_grad_data[index] = value;
   }
 }
+
+template <typename T, bool fuse_relu_before_conv>
+__device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
+    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
+  const int batch = blockIdx.z;
+  int h_in = blockIdx.x * dilate_height + blockIdx.y;
+  if (h_in >= input_height) {
+    return;
+  }
+
+  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
+    for (int w_in = threadIdx.y; w_in < input_width; w_in += blockDim.y) {
+      int h_out_start =
+          h_in - (filter_height - 1) * dilate_height + padding_height;
+      int w_out_start =
+          w_in - (filter_width - 1) * dilate_width + padding_width;
+
+      T value = 0;
+      int index = ((batch * input_height + h_in) * input_width + w_in) *
+                      input_channels +
+                  c_in;
+      if (fuse_relu_before_conv) {
+        if (input_data[index] <= 0) {
+          input_grad_data[index] = 0;
+          continue;
+        }
+      }
+
+      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
+        int c_out = c_in * filter_multiplier + c_i;
+        int weight_offset = filter_height * filter_width;
+        for (int h_out = h_out_start, h_f = 0; h_f < filter_height;
+             h_out += dilate_height, h_f++) {
+          for (int w_out = w_out_start, w_f = 0; w_f < filter_width;
+               w_out += dilate_width, w_f++) {
+            weight_offset--;
+            int s_h_out = h_out / stride_height;
+            int s_w_out = w_out / stride_width;
+            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
+                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
+                s_w_out < output_width) {
+              int output_grad_offset =
+                  ((batch * output_height + s_h_out) * output_width + s_w_out) *
+                      output_channels +
+                  c_out;
+              int filter_offset = weight_offset * output_channels + c_out;
+              value += output_grad_data[output_grad_offset] *
+                       filter_data[filter_offset];
+            }
+          }
+        }
+      }
+      input_grad_data[index] = value;
+    }
+  }
+}
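Both input-grad kernels walk every filter tap and keep only taps whose corresponding output coordinate is integral (`h_out % stride == 0`) and in range. The mapping can be stated as a single predicate; a minimal sketch under the same conventions as the kernels above (`MapTapToOutput` is a hypothetical free function, not in the patch):

```cpp
// Given an input coordinate and one filter tap, decide whether an output
// pixel contributes to this input's gradient, and which one. Mirrors the
// h_out % stride == 0 && 0 <= s_h_out < out_size test in the kernels.
struct TapHit {
  bool valid;
  int out_coord;
};

__host__ __device__ inline TapHit MapTapToOutput(int in_coord, int tap,
                                                 int dilation, int padding,
                                                 int stride, int out_size) {
  // Unstrided output position that, through this tap, touches in_coord.
  int out_unstrided = in_coord + padding - tap * dilation;
  if (out_unstrided % stride != 0) return {false, 0};
  int out = out_unstrided / stride;
  return {out >= 0 && out < out_size, out};
}
```

This is also why `weight_offset` runs backwards in the NHWC kernel: the tap index seen by the gradient is the 180-degree-rotated filter position.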
@@ -362,10 +441,10 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
 
 template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
           bool fuse_relu_before_conv>
-__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
+__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
     ARG_DEFINE_KernelDepthwiseConvInputGrad) {
-  const int kWeghtSize = c_filter * c_filter * c_filter_multiplier + 1;
-  T r_weight[kWeghtSize];
+  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
+  T r_weight[kWeightSize];
   const int batch = blockIdx.y;
   const int c_in = blockIdx.x;
 
@@ -379,24 +458,13 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
 
   for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
     for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
-      const int batch = blockIdx.y;
-      const int c_in = blockIdx.x;
-
       int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;
-
       int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
 
       T value = 0;
-      int index;
-      if (data_layout != DataLayout::kNHWC) {
-        index =
-            ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
-            w_in;
-      } else {
-        index =
-            ((batch * input_height + h_in) * input_width + w_in) * gridDim.x +
-            c_in;
-      }
+      int index =
+          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
+          w_in;
       if (fuse_relu_before_conv) {
         if (input_data[index] <= 0) {
           input_grad_data[index] = 0;
@@ -415,20 +483,11 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
           if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
               s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
               s_w_out < output_width) {
-            int output_grad_offset;
-            if (data_layout != DataLayout::kNHWC) {
-              output_grad_offset =
-                  ((batch * output_channels + c_out) * output_height +
-                   s_h_out) *
-                      output_width +
-                  s_w_out;
-            } else {
-              output_grad_offset =
-                  ((batch * output_height + s_h_out) * output_width +
-                   s_w_out) *
-                      output_channels +
-                  c_out;
-            }
+            int output_grad_offset =
+                ((batch * output_channels + c_out) * output_height +
+                 s_h_out) *
+                    output_width +
+                s_w_out;
             value += output_grad_data[output_grad_offset] *
                      r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
@@ -441,47 +500,137 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
     }
   }
 }
 
-template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
-          bool fuse_relu_before_conv>
+template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
+          bool fuse_relu_before_conv>
+__device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
+    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
+  int h_in = blockIdx.x * dilate_height + blockIdx.y;
+  if (h_in >= input_height) {
+    return;
+  }
+  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
+  T r_weight[kWeightSize];
+  const int batch = blockIdx.z;
+  const int wi_size = (input_width + dilate_width - 1) / dilate_width;
+  const int h_out_start =
+      h_in - (c_filter - 1) * dilate_height + padding_height;
+
+  for (int c_in = threadIdx.x; c_in < input_channels; c_in += blockDim.x) {
+    for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
+      int c_out = c_in * c_filter_multiplier + c_i;
+      for (int i = 0; i < c_filter * c_filter; i++)
+        r_weight[i + c_i * c_filter * c_filter] =
+            filter_data[(c_filter * c_filter - i - 1) * output_channels +
+                        c_out];
+    }
+    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
+      int i_dw = i / wi_size;
+      int i_wi = i - i_dw * wi_size;
+      int w_in = i_wi * dilate_width + i_dw;
+      if (w_in >= input_width) {
+        continue;
+      }
+      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
+
+      T value = 0;
+      int index = ((batch * input_height + h_in) * input_width + w_in) *
+                      input_channels +
+                  c_in;
+      if (fuse_relu_before_conv) {
+        if (input_data[index] <= 0) {
+          input_grad_data[index] = 0;
+          continue;
+        }
+      }
+
+      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
+        int c_out = c_in * c_filter_multiplier + c_i;
+        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
+             h_out += dilate_height, h_f++) {
+          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
+               w_out += dilate_width, w_f++) {
+            int s_h_out = h_out / stride_height;
+            int s_w_out = w_out / stride_width;
+            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
+                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
+                s_w_out < output_width) {
+              int output_grad_offset =
+                  ((batch * output_height + s_h_out) * output_width + s_w_out) *
+                      output_channels +
+                  c_out;
+              value +=
+                  output_grad_data[output_grad_offset] *
+                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
+            }
+          }
+        }
+      }
+      input_grad_data[index] = value;
+    }
+  }
+}
+
+template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
+          DataLayout data_layout, bool fuse_relu_before_conv>
 __global__ void KernelDepthwiseConvInputGradSp(
     ARG_DEFINE_KernelDepthwiseConvInputGrad) {
-  if (c_filter_multiplier == 0)
-    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
-        input_data, output_grad_data, filter_data, batch_size, output_channels,
-        output_height, output_width, input_channels, input_height, input_width,
-        filter_multiplier, filter_height, filter_width, stride_height,
-        stride_width, padding_height, padding_width, dilate_height,
-        dilate_width, input_grad_data, data_layout);
-  else if (c_filter == -1)
-    KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
-        input_data, output_grad_data, filter_data, batch_size, output_channels,
-        output_height, output_width, input_channels, input_height, input_width,
-        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
-        padding_height, padding_width, dilate_height, dilate_width,
-        input_grad_data, data_layout);
-  else
-    KernelDepthwiseConvInputGradCFilter<T, c_filter_multiplier, c_stride,
-                                        c_filter, fuse_relu_before_conv>(
-        input_data, output_grad_data, filter_data, batch_size, output_channels,
-        output_height, output_width, input_channels, input_height, input_width,
-        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
-        padding_height, padding_width, dilate_height, dilate_width,
-        input_grad_data, data_layout);
+  int final_filter_multiplier = filter_multiplier;
+  int h_stride = stride_height;
+  int w_stride = stride_width;
+  if (c_filter_multiplier != 0) {
+    final_filter_multiplier = c_filter_multiplier;
+    h_stride = c_stride;
+    w_stride = c_stride;
+  }
+
+  if (c_filter_multiplier == 0 || c_filter == -1) {
+    if (data_layout != DataLayout::kNHWC) {
+      KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>(
+          input_data, output_grad_data, filter_data, batch_size,
+          output_channels, output_height, output_width, input_channels,
+          input_height, input_width, final_filter_multiplier, filter_height,
+          filter_width, h_stride, w_stride, padding_height, padding_width,
+          dilate_height, dilate_width, input_grad_data);
+    } else {
+      KernelDepthwiseConvInputGradNHWC<T, fuse_relu_before_conv>(
+          input_data, output_grad_data, filter_data, batch_size,
+          output_channels, output_height, output_width, input_channels,
+          input_height, input_width, final_filter_multiplier, filter_height,
+          filter_width, h_stride, w_stride, padding_height, padding_width,
+          dilate_height, dilate_width, input_grad_data);
+    }
+  } else {
+    if (data_layout != DataLayout::kNHWC) {
+      KernelDepthwiseConvInputGradCFilterNCHW<T, c_filter_multiplier, c_stride,
+                                              c_filter, fuse_relu_before_conv>(
+          input_data, output_grad_data, filter_data, batch_size,
+          output_channels, output_height, output_width, input_channels,
+          input_height, input_width, c_filter_multiplier, filter_height,
+          filter_width, c_stride, c_stride, padding_height, padding_width,
+          dilate_height, dilate_width, input_grad_data);
+    } else {
+      KernelDepthwiseConvInputGradCFilterNHWC<T, c_filter_multiplier, c_stride,
+                                              c_filter, fuse_relu_before_conv>(
+          input_data, output_grad_data, filter_data, batch_size,
+          output_channels, output_height, output_width, input_channels,
+          input_height, input_width, c_filter_multiplier, filter_height,
+          filter_width, c_stride, c_stride, padding_height, padding_width,
+          dilate_height, dilate_width, input_grad_data);
+    }
+  }
 }
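`KernelDepthwiseConvInputGradSp`, like its forward and filter-grad counterparts, is instantiated once per (filter multiplier, stride, filter size, layout) combination, so the inner loops see compile-time constants and can be fully unrolled; the `check_case` macros later in the patch pick an instantiation at launch time. A simplified, self-contained sketch of the same dispatch pattern (hypothetical names, plain C++):

```cpp
#include <cstdio>

enum class DataLayout { kNCHW, kNHWC };

// Compile-time specialized worker; c_filter = -1 means "generic filter size".
template <int c_stride, int c_filter, DataLayout layout>
void RunSpecialized(int stride, int filter) {
  // With c_stride/c_filter as literals, the compiler can unroll the tap
  // loops and fold the index arithmetic, which is the point of the scheme.
  std::printf("stride=%d filter=%d -> specialization <%d, %d, %s>\n", stride,
              filter, c_stride, c_filter,
              layout == DataLayout::kNHWC ? "NHWC" : "NCHW");
}

// Runtime dispatch in the style of the patch's check_case macros.
void Dispatch(int stride, int filter, DataLayout layout) {
#define CHECK_CASE(c_stride, c_filter)                                       \
  if (stride == c_stride && (filter == c_filter || c_filter == -1)) {        \
    if (layout == DataLayout::kNHWC)                                         \
      RunSpecialized<c_stride, c_filter, DataLayout::kNHWC>(stride, filter); \
    else                                                                     \
      RunSpecialized<c_stride, c_filter, DataLayout::kNCHW>(stride, filter); \
    return;                                                                  \
  }
  CHECK_CASE(1, 3);
  CHECK_CASE(1, 5);
  CHECK_CASE(1, -1);  // catch-all for stride 1
  CHECK_CASE(2, 3);
  CHECK_CASE(2, 5);
  CHECK_CASE(2, -1);  // catch-all for stride 2
#undef CHECK_CASE
}
```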
 // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
 template <typename T, bool fuse_relu_before_conv>
-__device__ __inline__ void KernelDepthwiseConvFilterGrad(
+__device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
     const T* output_grad_data, const T* input_data, const int num,
     const int output_channels, const int output_height, const int output_width,
     const int input_channels, const int input_height, const int input_width,
     const int filter_multiplier, const int filter_height,
     const int filter_width, const int stride_height, const int stride_width,
     const int padding_height, const int padding_width, const int dilate_height,
-    const int dilate_width, T* filter_grad_data,
-    const DataLayout data_layout = DataLayout::kNCHW) {
+    const int dilate_width, T* filter_grad_data) {
   T s = 0;
-
   int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
 
   for (int image_w = threadIdx.x; image_w < output_width;
@@ -499,45 +648,137 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
         if (image_wk < 0 || image_wk >= input_width) continue;
 #define gaid(N, C, H, W) \
   ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
-#define gaid_nhwc(N, H, W, C) \
-  ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
-        int input_id;
-        if (data_layout != DataLayout::kNHWC) {
-          input_id = ((bid * (gridDim.z / filter_multiplier) +
-                       kernel_id / filter_multiplier) *
-                          input_height +
-                      image_hk) *
-                         input_width +
-                     image_wk;
-          if (fuse_relu_before_conv) {
-            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-                 max(0.0f, input_data[input_id]);
-          } else {
-            s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-                 input_data[input_id];
-          }
+        int input_id = ((bid * (gridDim.z / filter_multiplier) +
+                         kernel_id / filter_multiplier) *
+                            input_height +
+                        image_hk) *
+                           input_width +
+                       image_wk;
+        if (fuse_relu_before_conv) {
+          s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
+               max(0.0f, input_data[input_id]);
         } else {
-          input_id =
-              ((bid * input_height + image_hk) * input_width + image_wk) *
-                  (gridDim.z / filter_multiplier) +
-              kernel_id / filter_multiplier;
-          if (fuse_relu_before_conv) {
-            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
-                 max(0.0f, input_data[input_id]);
-          } else {
-            s += output_grad_data[gaid_nhwc(bid, image_h, image_w, kernel_id)] *
-                 input_data[input_id];
-          }
+          s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
+               input_data[input_id];
         }
-
 #undef gaid
       }
     }
   }
   CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
 }
+
+template <typename T, bool fuse_relu_before_conv>
+__device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
+    const T* output_grad_data, const T* input_data, const int num,
+    const int output_channels, const int output_height, const int output_width,
+    const int input_channels, const int input_height, const int input_width,
+    const int filter_multiplier, const int filter_height,
+    const int filter_width, const int stride_height, const int stride_width,
+    const int padding_height, const int padding_width, const int dilate_height,
+    const int dilate_width, T* filter_grad_data) {
+  int bid = blockIdx.z;
+  int image_h = blockIdx.y;
+  int kernel_iw = blockIdx.x % filter_width;
+  int kernel_ih = blockIdx.x / filter_width;
+  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
+       kernel_id += blockDim.x) {
+    T s = 0;
+    int gbid =
+        ((kernel_id * filter_height) + kernel_ih) * filter_width + kernel_iw;
+    for (int image_w = threadIdx.y; image_w < output_width;
+         image_w += blockDim.y) {
+      int kernel_h = kernel_ih * dilate_height - padding_height;
+      int kernel_w = kernel_iw * dilate_width - padding_width;
+
+      int image_hk = image_h * stride_height + kernel_h;
+      int image_wk = image_w * stride_width + kernel_w;
+      if (image_hk < 0 || image_hk >= input_height) continue;
+      if (image_wk < 0 || image_wk >= input_width) continue;
+#define gaid(N, H, W, C) \
+  ((((N)*output_height + (H)) * output_width + (W)) * output_channels + (C))
+      int input_id =
+          ((bid * input_height + image_hk) * input_width + image_wk) *
+              input_channels +
+          kernel_id / filter_multiplier;
+      if (fuse_relu_before_conv) {
+        s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
+             max(0.0f, input_data[input_id]);
+      } else {
+        s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
+             input_data[input_id];
+      }
+#undef gaid
+    }
+    platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
+  }
+}
+
+template <typename T, int c_filter, bool fuse_relu_before_conv>
+__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
+    const T* output_grad_data, const T* input_data, const int num,
+    const int output_channels, const int output_height, const int output_width,
+    const int input_channels, const int input_height, const int input_width,
+    const int filter_multiplier, const int filter_height,
+    const int filter_width, const int stride_height, const int stride_width,
+    const int padding_height, const int padding_width, const int dilate_height,
+    const int dilate_width, T* filter_grad_data) {
+  const int bid = blockIdx.z;
+  int image_h = blockIdx.x * dilate_height + blockIdx.y;
+  if (image_h >= output_height) {
+    return;
+  }
+  const int kWeightSize = c_filter * c_filter;
+  T r_weight[kWeightSize];
+  const int wi_size = (output_width + dilate_width - 1) / dilate_width;
+
+  for (int kernel_id = threadIdx.x; kernel_id < output_channels;
+       kernel_id += blockDim.x) {
+    for (int i = 0; i < c_filter * c_filter; ++i) {
+      r_weight[i] = 0;
+    }
+    for (int i = threadIdx.y; i < wi_size * dilate_width; i += blockDim.y) {
+      int i_dw = i / wi_size;
+      int i_wi = i - i_dw * wi_size;
+      int image_w = i_wi * dilate_width + i_dw;
+      if (image_w >= output_width) {
+        continue;
+      }
+      for (int kernel_ih = 0; kernel_ih < c_filter; ++kernel_ih) {
+        for (int kernel_iw = 0; kernel_iw < c_filter; ++kernel_iw) {
+          int kernel_h = kernel_ih * dilate_height - padding_height;
+          int kernel_w = kernel_iw * dilate_width - padding_width;
+          int image_hk = image_h * stride_height + kernel_h;
+          int image_wk = image_w * stride_width + kernel_w;
+          if (image_hk < 0 || image_hk >= input_height) continue;
+          if (image_wk < 0 || image_wk >= input_width) continue;
+          int input_id =
+              ((bid * input_height + image_hk) * input_width + image_wk) *
+                  input_channels +
+              kernel_id / filter_multiplier;
+          int output_id =
+              ((bid * output_height + image_h) * output_width + image_w) *
+                  output_channels +
+              kernel_id;
+          T s = 0;
+          if (fuse_relu_before_conv) {
+            s = output_grad_data[output_id] * max(0.0f, input_data[input_id]);
+          } else {
+            s = output_grad_data[output_id] * input_data[input_id];
+          }
+          r_weight[kernel_ih * c_filter + kernel_iw] += s;
+        }
+      }
+    }
+    for (int i = 0; i < c_filter * c_filter; ++i) {
+      T* weight = filter_grad_data + i * output_channels + kernel_id;
+      platform::CudaAtomicAdd(&weight[0], r_weight[i]);
+    }
+  }
+}
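`KernelDepthwiseConvFilterGradCFilterNHWC` accumulates each thread's c_filter x c_filter partial sums in registers (`r_weight[]`) and issues a single `atomicAdd` per filter tap at the end, instead of one atomic per visited pixel. The same idea in miniature, as a self-contained CUDA sketch (hypothetical kernel, assuming per-pixel partials are already materialized):

```cuda
// Accumulate locally, publish once: each thread folds its partial sums into
// registers, then does one atomicAdd per tap rather than one per pixel.
template <int FILTER>
__global__ void FilterGradAccumulate(const float* partials, int n_pixels,
                                     float* grad /* FILTER*FILTER entries */) {
  float local[FILTER * FILTER] = {0.0f};  // plays the role of r_weight[]
  for (int p = threadIdx.x; p < n_pixels; p += blockDim.x) {
    for (int t = 0; t < FILTER * FILTER; ++t) {
      local[t] += partials[p * FILTER * FILTER + t];
    }
  }
  // One atomicAdd per tap per thread, instead of one per (pixel, tap) pair.
  for (int t = 0; t < FILTER * FILTER; ++t) {
    atomicAdd(&grad[t], local[t]);
  }
}
```

This only works when the filter size is a compile-time constant, which is why the register-tiled path exists only for the `c_filter != -1` specializations.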
 
-template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
+template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
+          DataLayout data_layout, bool fuse_relu_before_conv>
 __global__ void KernelDepthwiseConvFilterGradSp(
     const T* output_grad_data, const T* input_data, const int num,
     const int output_channels, const int output_height, const int output_width,
     const int input_channels, const int input_height, const int input_width,
     const int filter_multiplier, const int filter_height,
     const int filter_width, const int stride_height, const int stride_width,
     const int padding_height, const int padding_width, const int dilate_height,
-    const int dilate_width, T* filter_grad_data,
-    const DataLayout data_layout = DataLayout::kNCHW) {
-  if (c_filter_multiplier == 0)
-    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
-        output_grad_data, input_data, num, output_channels, output_height,
-        output_width, input_channels, input_height, input_width,
-        filter_multiplier, filter_height, filter_width, stride_height,
-        stride_width, padding_height, padding_width, dilate_height,
-        dilate_width, filter_grad_data, data_layout);
-  else
-    KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
-        output_grad_data, input_data, num, output_channels, output_height,
-        output_width, input_channels, input_height, input_width,
-        c_filter_multiplier, filter_height, filter_width, stride_height,
-        stride_width, padding_height, padding_width, dilate_height,
-        dilate_width, filter_grad_data, data_layout);
+    const int dilate_width, T* filter_grad_data) {
+  int final_filter_multiplier = filter_multiplier;
+  int h_stride = stride_height;
+  int w_stride = stride_width;
+  if (c_filter_multiplier != 0) {
+    final_filter_multiplier = c_filter_multiplier;
+    h_stride = c_stride;
+    w_stride = c_stride;
+  }
+  if (c_filter_multiplier == 0 || c_filter == -1) {
+    if (data_layout != DataLayout::kNHWC) {
+      KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
+          output_grad_data, input_data, num, output_channels, output_height,
+          output_width, input_channels, input_height, input_width,
+          final_filter_multiplier, filter_height, filter_width, h_stride,
+          w_stride, padding_height, padding_width, dilate_height, dilate_width,
+          filter_grad_data);
+    } else {
+      KernelDepthwiseConvFilterGradNHWC<T, fuse_relu_before_conv>(
+          output_grad_data, input_data, num, output_channels, output_height,
+          output_width, input_channels, input_height, input_width,
+          final_filter_multiplier, filter_height, filter_width, h_stride,
+          w_stride, padding_height, padding_width, dilate_height, dilate_width,
+          filter_grad_data);
+    }
+  } else {
+    if (data_layout != DataLayout::kNHWC) {
+      KernelDepthwiseConvFilterGradNCHW<T, fuse_relu_before_conv>(
+          output_grad_data, input_data, num, output_channels, output_height,
+          output_width, input_channels, input_height, input_width,
+          final_filter_multiplier, filter_height, filter_width, h_stride,
+          w_stride, padding_height, padding_width, dilate_height, dilate_width,
+          filter_grad_data);
+    } else {
+      KernelDepthwiseConvFilterGradCFilterNHWC<T, c_filter,
+                                               fuse_relu_before_conv>(
+          output_grad_data, input_data, num, output_channels, output_height,
+          output_width, input_channels, input_height, input_width,
+          final_filter_multiplier, filter_height, filter_width, h_stride,
+          w_stride, padding_height, padding_width, dilate_height, dilate_width,
+          filter_grad_data);
+    }
+  }
 }
 
 /*
@@ -608,19 +876,45 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
     const T* filter_data = filter.data<T>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
+    framework::Tensor filter_hwc;
+    if (data_layout == DataLayout::kNHWC) {
+      framework::DDim filter_hwc_dims({filter.dims()[2], filter.dims()[3],
+                                       filter.dims()[0], filter.dims()[1]});
+      filter_hwc.Resize(filter_hwc_dims);
+      filter_hwc.mutable_data<T>(context.GetPlace());
+      std::vector<int> perm_axis({2, 3, 0, 1});
+      math::TransposeNormal<platform::CUDADeviceContext, T> trans;
+      trans(context, filter, &filter_hwc, perm_axis);
+      filter_data = filter_hwc.data<T>();
+    }
+
     int thread = 512;
-    if (output_width > 1024 && output_width <= 2048)
-      thread = (output_width - 1) / 2 + 1;
-    else if (output_width > 512 && output_width <= 1024)
-      thread = output_width;
+    int blocks;
+    dim3 threads;
+    dim3 grid;
+    if (data_layout != DataLayout::kNHWC) {
+      if (output_width > 1024 && output_width <= 2048)
+        thread = (output_width - 1) / 2 + 1;
+      else if (output_width > 512 && output_width <= 1024)
+        thread = output_width;
+#ifdef __HIPCC__
+      thread = std::min(thread, 256);
+#endif
+      blocks = std::min(std::max(thread / output_width, 1), output_height);
+      threads = dim3(std::min(output_width, thread), blocks, 1);
+      grid = dim3(output_channels, batch_size, 1);
+    } else {
 #ifdef __HIPCC__
-    thread = std::min(thread, 256);
+      thread = std::min(thread, 256);
 #endif
-    int blocks = std::min(std::max(thread / output_width, 1), output_height);
-    dim3 threads(std::min(output_width, thread), blocks, 1);
-    dim3 grid(output_channels, batch_size, 1);
+      blocks = std::min(
+          std::max(thread / output_channels, 1),
+          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
+      threads = dim3(std::min(output_channels, thread), blocks, 1);
+      grid = dim3((output_height + dilate_height - 1) / dilate_height,
+                  dilate_height, batch_size);
+    }
     int filter_multiplier = output_channels / input_channels;
-
     int nums_output =
         batch_size * output_channels * output_height * output_width;
 #ifdef __HIPCC__
@@ -631,26 +925,37 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
-#define check_case(c_filter_multiplier, c_stride, c_filter)                   \
-  if (c_filter_multiplier == 0 ||                                             \
-      filter_multiplier == c_filter_multiplier &&                             \
-          stride_height == stride_width && stride_height == c_stride &&       \
-          (ksize_height == ksize_width && ksize_height == c_filter ||         \
-           c_filter == -1)) {                                                 \
-    KernelDepthwiseConvSp<T, c_filter_multiplier, c_stride, c_filter,         \
-                          fuse_relu_before_conv>                              \
-        <<<grid, threads, 0, context.stream()>>>(                             \
-        input_data, filter_data, batch_size, output_channels, output_height,  \
-        output_width, input_channels, input_height, input_width,              \
-        filter_multiplier, ksize_height, ksize_width, stride_height,          \
-        stride_width, padding_height, padding_width, dilate_height,           \
-        dilate_width, output_data, data_layout);                              \
-    return;                                                                   \
+#define check_case(c_filter_multiplier, c_stride, c_filter)                   \
+  if (c_filter_multiplier == 0 ||                                             \
+      filter_multiplier == c_filter_multiplier &&                             \
+          stride_height == stride_width && stride_height == c_stride &&       \
+          (ksize_height == ksize_width && ksize_height == c_filter ||         \
+           c_filter == -1)) {                                                 \
+    if (c_filter == -1) {                                                     \
+      threads.x = block_size;                                                 \
+      grid.x = grid_size;                                                     \
+      threads.y = threads.z = grid.y = grid.z = 1;                            \
+    }                                                                         \
+    if (data_layout != DataLayout::kNHWC) {                                   \
+      KernelDepthwiseConvSp<                                                  \
+          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW,      \
+          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(     \
+          input_data, filter_data, batch_size, output_channels, output_height,\
+          output_width, input_channels, input_height, input_width,            \
+          filter_multiplier, ksize_height, ksize_width, stride_height,        \
+          stride_width, padding_height, padding_width, dilate_height,         \
+          dilate_width, output_data);                                         \
+    } else {                                                                  \
+      KernelDepthwiseConvSp<                                                  \
+          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC,      \
+          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(     \
+          input_data, filter_data, batch_size, output_channels, output_height,\
+          output_width, input_channels, input_height, input_width,            \
+          filter_multiplier, ksize_height, ksize_width, stride_height,        \
+          stride_width, padding_height, padding_width, dilate_height,         \
+          dilate_width, output_data);                                         \
+    }                                                                         \
+    return;                                                                   \
   }
   check_case(1, 1, 3);
   check_case(1, 1, 5);
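For NHWC, each functor first transposes the filter from Paddle's default (C_out, M, Kh, Kw) layout to tap-major "HWC" order with `math::TransposeNormal` and perm `{2, 3, 0, 1}`, so the kernels can read every tap channel-contiguously. A host-side sketch of that permutation (`ToHWC` is a hypothetical helper, not the functor code):

```cpp
#include <vector>

// Transpose a depthwise filter from (C_out, M, Kh, Kw) order to
// (Kh, Kw, C_out, M) order -- the perm {2, 3, 0, 1} the patch feeds to
// math::TransposeNormal. For depthwise conv M is 1, so the result is
// effectively (Kh, Kw, C_out): tap-major, channel-minor.
std::vector<float> ToHWC(const std::vector<float>& w, int c_out, int m,
                         int kh, int kw) {
  std::vector<float> out(w.size());
  for (int c = 0; c < c_out; ++c)
    for (int i = 0; i < m; ++i)
      for (int h = 0; h < kh; ++h)
        for (int k = 0; k < kw; ++k)
          out[((h * kw + k) * c_out + c) * m + i] =
              w[((c * m + i) * kh + h) * kw + k];
  return out;
}
```

In the filter-gradient path the same permutation is applied in reverse after the kernel runs, which is why the `check_case` macro below allocates a zeroed `filter_grad_hwc` scratch tensor for the specialized NHWC cases.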
@@ -714,32 +1019,67 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
     const T* filter_data = filter.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
 
+    framework::Tensor filter_hwc;
+    if (data_layout == DataLayout::kNHWC) {
+      framework::DDim filter_hwc_dims({filter.dims()[2], filter.dims()[3],
+                                       filter.dims()[0], filter.dims()[1]});
+      filter_hwc.Resize(filter_hwc_dims);
+      filter_hwc.mutable_data<T>(context.GetPlace());
+      std::vector<int> perm_axis({2, 3, 0, 1});
+      math::TransposeNormal<platform::CUDADeviceContext, T> trans;
+      trans(context, filter, &filter_hwc, perm_axis);
+      filter_data = filter_hwc.data<T>();
+    }
+
     int thread = 512;
-    if (input_width > 1024 && input_width <= 2048)
-      thread = (input_width - 1) / 2 + 1;
-    else if (input_width > 512 && input_width <= 1024)
-      thread = input_width;
-    int blocks = std::min(std::max(thread / input_width, 1), input_height);
-    dim3 threads(std::min(input_width, thread), blocks, 1);
-    dim3 grid(input_channels, batch_size, 1);
+    int blocks;
+    dim3 threads;
+    dim3 grid;
+    if (data_layout != DataLayout::kNHWC) {
+      if (input_width > 1024 && input_width <= 2048) {
+        thread = (input_width - 1) / 2 + 1;
+      } else if (input_width > 512 && input_width <= 1024) {
+        thread = input_width;
+      }
+      blocks = std::min(std::max(thread / input_width, 1), input_height);
+      threads = dim3(std::min(input_width, thread), blocks, 1);
+      grid = dim3(input_channels, batch_size, 1);
+    } else {
+      blocks = std::min(
+          std::max(thread / input_channels, 1),
+          ((input_width + dilate_width - 1) / dilate_width) * dilate_width);
+      threads = dim3(std::min(input_channels, thread), blocks, 1);
+      grid = dim3((input_height + dilate_height - 1) / dilate_height,
+                  dilate_height, batch_size);
+    }
     int filter_multiplier = output_channels / input_channels;
 
-#define check_case(c_filter_multiplier, c_stride, c_filter)             \
-  if (c_filter_multiplier == 0 ||                                       \
-      filter_multiplier == c_filter_multiplier &&                       \
-          stride_height == stride_width && stride_height == c_stride && \
-          (ksize_height == ksize_width && ksize_height == c_filter ||   \
-           c_filter == -1)) {                                           \
-    KernelDepthwiseConvInputGradSp<                                     \
-        T, c_filter_multiplier, c_stride, c_filter,                     \
-        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
-        input_data, output_grad_data, filter_data, batch_size,          \
-        output_channels, output_height, output_width, input_channels,   \
-        input_height, input_width, filter_multiplier, ksize_height,     \
-        ksize_width, stride_height, stride_width, padding_height,       \
-        padding_width, dilate_height, dilate_width, input_grad_data,    \
-        data_layout);                                                   \
-    return;                                                             \
+#define check_case(c_filter_multiplier, c_stride, c_filter)               \
+  if (c_filter_multiplier == 0 ||                                         \
+      filter_multiplier == c_filter_multiplier &&                         \
+          stride_height == stride_width && stride_height == c_stride &&   \
+          (ksize_height == ksize_width && ksize_height == c_filter ||     \
+           c_filter == -1)) {                                             \
+    if (data_layout != DataLayout::kNHWC) {                               \
+      KernelDepthwiseConvInputGradSp<                                     \
+          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW,  \
+          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
+          input_data, output_grad_data, filter_data, batch_size,          \
+          output_channels, output_height, output_width, input_channels,   \
+          input_height, input_width, filter_multiplier, ksize_height,     \
+          ksize_width, stride_height, stride_width, padding_height,       \
+          padding_width, dilate_height, dilate_width, input_grad_data);   \
+    } else {                                                              \
+      KernelDepthwiseConvInputGradSp<                                     \
+          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC,  \
+          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
+          input_data, output_grad_data, filter_data, batch_size,          \
+          output_channels, output_height, output_width, input_channels,   \
+          input_height, input_width, filter_multiplier, ksize_height,     \
+          ksize_width, stride_height, stride_width, padding_height,       \
+          padding_width, dilate_height, dilate_width, input_grad_data);   \
+    }                                                                     \
+    return;                                                               \
   }
   check_case(1, 1, 3);
   check_case(1, 1, 5);
@@ -802,30 +1142,95 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
     T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());
 
     int block_size = 512;
-    if (output_width > 1024 && output_width <= 2048)
-      block_size = (output_width - 1) / 2 + 1;
-    else if (output_width > 512 && output_width <= 1024)
-      block_size = output_width;
-    int crop_output_height =
-        std::min(std::max(block_size / output_width, 1), output_height);
-    dim3 grid(ksize_width, ksize_height, output_channels);
-    dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
+    int blocks;
+    dim3 threads;
+    dim3 grid;
+    if (data_layout != DataLayout::kNHWC) {
+      if (output_width > 1024 && output_width <= 2048) {
+        block_size = (output_width - 1) / 2 + 1;
+      } else if (output_width > 512 && output_width <= 1024) {
+        block_size = output_width;
+      }
+      blocks = std::min(std::max(block_size / output_width, 1), output_height);
+      grid = dim3(ksize_width, ksize_height, output_channels);
+      threads = dim3(std::min(output_width, block_size), blocks, 1);
+    } else {
+      blocks = std::min(
+          std::max(block_size / output_channels, 1),
+          ((output_width + dilate_width - 1) / dilate_width) * dilate_width);
+      grid = dim3((output_height + dilate_height - 1) / dilate_height,
+                  dilate_height, batch_size);
+      threads = dim3(std::min(output_channels, block_size), blocks, 1);
+    }
     int filter_multiplier = output_channels / input_channels;
 
-#define check_case(c_filter_multiplier)                                       \
-  if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \
-    KernelDepthwiseConvFilterGradSp<                                          \
-        T, c_filter_multiplier,                                               \
-        fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(       \
-        output_grad_data, input_data, batch_size, output_channels,            \
-        output_height, output_width, input_channels, input_height,            \
-        input_width, filter_multiplier, ksize_height, ksize_width,            \
-        stride_height, stride_width, padding_height, padding_width,           \
-        dilate_height, dilate_width, filter_grad_data, data_layout);          \
-    return;                                                                   \
+#define check_case(c_filter_multiplier, c_stride, c_filter)                   \
+  if (c_filter_multiplier == 0 ||                                             \
+      filter_multiplier == c_filter_multiplier &&                             \
+          stride_height == stride_width && stride_height == c_stride &&       \
+          (ksize_height == ksize_width && ksize_height == c_filter ||         \
+           c_filter == -1)) {                                                 \
+    if (data_layout != DataLayout::kNHWC) {                                   \
+      KernelDepthwiseConvFilterGradSp<                                        \
+          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNCHW,      \
+          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(     \
+          output_grad_data, input_data, batch_size, output_channels,          \
+          output_height, output_width, input_channels, input_height,          \
+          input_width, filter_multiplier, ksize_height, ksize_width,          \
+          stride_height, stride_width, padding_height, padding_width,         \
+          dilate_height, dilate_width, filter_grad_data);                     \
+    } else {                                                                  \
+      framework::Tensor filter_grad_hwc;                                      \
+      if (c_filter != -1) {                                                   \
+        framework::DDim filter_grad_hwc_dims(                                 \
+            {filter_grad->dims()[2], filter_grad->dims()[3],                  \
+             filter_grad->dims()[0], filter_grad->dims()[1]});                \
+        filter_grad_hwc.Resize(filter_grad_hwc_dims);                         \
+        filter_grad_hwc.mutable_data<T>(context.GetPlace());                  \
+        math::SetConstant<platform::CUDADeviceContext, T> set_zero;           \
+        set_zero(context, &filter_grad_hwc, static_cast<T>(0));               \
+        filter_grad_data = filter_grad_hwc.data<T>();                         \
+      } else {                                                                \
+        block_size = 512;                                                     \
+        if (output_channels > 1024 && output_channels <= 2048) {              \
+          block_size = (output_channels - 1) / 2 + 1;                         \
+        } else if (output_channels > 512 && output_channels <= 1024) {        \
+          block_size = output_channels;                                       \
+        }                                                                     \
+        blocks =                                                              \
+            std::min(std::max(block_size / output_channels, 1), output_width);\
+        grid = dim3(ksize_width * ksize_height, output_height, batch_size);   \
+        threads = dim3(std::min(output_channels, block_size), blocks, 1);     \
+      }                                                                       \
+      KernelDepthwiseConvFilterGradSp<                                        \
+          T, c_filter_multiplier, c_stride, c_filter, DataLayout::kNHWC,      \
+          fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>(     \
+          output_grad_data, input_data, batch_size, output_channels,          \
+          output_height, output_width, input_channels, input_height,          \
+          input_width, filter_multiplier, ksize_height, ksize_width,          \
+          stride_height, stride_width, padding_height, padding_width,         \
+          dilate_height, dilate_width, filter_grad_data);                     \
+      if (c_filter != -1) {                                                   \
+        std::vector<int> perm_axis({2, 3, 0, 1});                             \
+        math::TransposeNormal<platform::CUDADeviceContext, T> trans;          \
+        trans(context, filter_grad_hwc, filter_grad, perm_axis);              \
+      }                                                                       \
+    }                                                                         \
+    return;                                                                   \
   }
-  check_case(1);
-  check_case(0);
+  check_case(1, 1, 3);
+  check_case(1, 1, 5);
+  check_case(1, 1, -1);
+  check_case(1, 2, 3);
+  check_case(1, 2, 5);
+  check_case(1, 2, -1);
+  check_case(2, 1, 3);
+  check_case(2, 1, 5);
+  check_case(2, 1, -1);
+  check_case(2, 2, 3);
+  check_case(2, 2, 5);
+  check_case(2, 2, -1);
+  check_case(0, 0, -1);
 #undef check_case
   }
 };
diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 75dc62e530d0db81ee4126dc76918e2f08713d30..3a520615625324f8bc675fbbe855016d22fe8d2e 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -112,10 +112,6 @@ def _conv_nd(x,
 
     # Due to the poor performance of NHWC, we transpose the input to NCHW.
     origin_format = data_format
-    if origin_format == "NHWC" and op_type == "depthwise_conv2d":
-        x = nn.transpose(x, perm=[0, 3, 1, 2])
-        data_format = "NCHW"
-        channel_dim = 1
     if in_dygraph_mode():
         attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
                  'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
@@ -159,10 +155,6 @@ def _conv_nd(x,
                 'use_mkldnn': use_mkldnn})
     else:
         out = pre_bias
-
-    if origin_format == "NHWC" and op_type == "depthwise_conv2d":
-        out = nn.transpose(out, perm=[0, 2, 3, 1])
-
     return out
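The Python hunk completes the change: `_conv_nd` no longer transposes NHWC inputs to NCHW and back around `depthwise_conv2d`, since the CUDA functors now dispatch on `data_format` directly. For reference, all three functors share the same NHWC launch geometry: threads.x walks channels (the fast NHWC axis), threads.y walks width positions grouped by dilation, and the grid tiles height x dilation x batch. A host-side sketch of that arithmetic (`Dim3` is a hypothetical stand-in for CUDA's `dim3`):

```cpp
#include <algorithm>

struct Dim3 { unsigned x, y, z; };

// Grid: one slice per (height tile, dilation phase, batch element).
Dim3 NhwcGrid(int height, int dilate_h, int batch) {
  return {static_cast<unsigned>((height + dilate_h - 1) / dilate_h),
          static_cast<unsigned>(dilate_h), static_cast<unsigned>(batch)};
}

// Block: channels on x (coalesced NHWC access), width groups on y,
// capped so x * y stays within the thread budget.
Dim3 NhwcThreads(int channels, int width, int dilate_w,
                 int max_threads = 512) {
  int blocks = std::min(std::max(max_threads / channels, 1),
                        ((width + dilate_w - 1) / dilate_w) * dilate_w);
  return {static_cast<unsigned>(std::min(channels, max_threads)),
          static_cast<unsigned>(blocks), 1u};
}
```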