diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 2c686c8ba7f56e70b283c160b1715869ac15b904..882b914f94fe454710f590d6aa0452627a88af75 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -45,67 +45,106 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. template -__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { - for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { - for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { - const int batch = blockIdx.y; - const int c_out = blockIdx.x; - - const int c_in = c_out / filter_multiplier; - const T* weight = filter_data + c_out * filter_height * filter_width; - T value = 0; - const int h_in_start = -padding_height + h_out * stride_height; - const int w_in_start = -padding_width + w_out * stride_width; - const int h_in_end = h_in_start + filter_height * dilate_height; - const int w_in_end = w_in_start + filter_width * dilate_width; - - int in_offset; - if (data_layout != DataLayout::kNHWC) { - in_offset = - ((batch * input_channels + c_in) * input_height) * input_width; - } else { - in_offset = batch * input_height * input_width * input_channels; +__device__ __inline__ void KernelDepthwiseConvNCHW( + ARG_DEFINE_KernelDepthwiseConv) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= (output_channels * batch_size * output_height * output_width)) + return; + + const int w_out = idx % output_width; + const int h_out = (idx / output_width) % output_height; + const int c_out = (idx / output_width / output_height) % output_channels; + const int batch = idx / output_width / output_height / output_channels; + + const int c_in = c_out / filter_multiplier; + const T* weight = filter_data + c_out * filter_height * filter_width; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + filter_height * dilate_height; + const int w_in_end = w_in_start + filter_width * dilate_width; + + int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + int weight_offset = 0; + +#pragma unroll + for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { +#pragma unroll + for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { + if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) { + int offset = in_offset + h_in * input_width + w_in; + T in_data = input_data[offset]; + if (fuse_relu_before_conv) { + value += weight[weight_offset] * max(0.0f, in_data); + } else { + value += weight[weight_offset] * in_data; + } } + weight_offset++; + } + } + int index = batch * output_channels * output_height * output_width + + c_out * output_height * output_width + h_out * output_width + + w_out; + output_data[index] = value; +} - const int h_end = h_in_end < input_height ? h_in_end : input_height; - const int w_end = w_in_end < input_width ? w_in_end : input_width; - const int h_start = h_in_start > 0 ? h_in_start : 0; - const int w_start = w_in_start > 0 ? w_in_start : 0; - int weight_offset = 0; - - for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { - for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { - if (h_in >= h_start && h_in < h_end && w_in >= w_start && - w_in < w_end) { - int offset; - if (data_layout != DataLayout::kNHWC) { - offset = in_offset + h_in * input_width + w_in; - } else { - offset = in_offset + - (h_in * input_width + w_in) * input_channels + c_in; - } - if (fuse_relu_before_conv) { - value += weight[weight_offset] * max(0.0f, input_data[offset]); - } else { - value += weight[weight_offset] * input_data[offset]; - } - } - weight_offset++; +// A Cuda kernel to compute the depthwise convolution forward pass +// in NHWC format. +template +__device__ __inline__ void KernelDepthwiseConvNHWC( + ARG_DEFINE_KernelDepthwiseConv) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= (output_channels * batch_size * output_height * output_width)) + return; + + const int c_out = idx % output_channels; + const int w_out = (idx / output_channels) % output_width; + const int h_out = (idx / output_channels / output_width) % output_height; + const int batch = idx / output_width / output_height / output_channels; + + const int c_in = c_out / filter_multiplier; + const T* weight = filter_data + c_out * filter_height * filter_width; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + filter_height * dilate_height; + const int w_in_end = w_in_start + filter_width * dilate_width; + + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + int weight_offset = 0; + +#pragma unroll + for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { +#pragma unroll + for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { + if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) { + int offset = ((batch * input_height + h_in) * input_width + w_in) * + output_channels + + c_in; + T in_data = input_data[offset]; + if (fuse_relu_before_conv) { + value += weight[weight_offset] * max(0.0f, in_data); + } else { + value += weight[weight_offset] * in_data; } } - int index; - if (data_layout != DataLayout::kNHWC) { - index = ((batch * gridDim.x + c_out) * output_height + h_out) * - output_width + - w_out; - } else { - index = ((batch * output_height + h_out) * output_width + w_out) * - gridDim.x + - c_out; - } - output_data[index] = value; + weight_offset++; } } + int index = batch * output_channels * output_height * output_width + + h_out * output_width * output_channels + w_out * output_channels + + c_out; + output_data[index] = value; } template @@ -183,36 +222,37 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( template __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { - if (c_filter_multiplier == 0) { - if (c_filter == -1) - KernelDepthwiseConv( - input_data, filter_data, batch_size, output_channels, output_height, - output_width, input_channels, input_height, input_width, - filter_multiplier, filter_height, filter_width, stride_height, - stride_width, padding_height, padding_width, dilate_height, - dilate_width, output_data, data_layout); - else - KernelDepthwiseConvCFilter( - input_data, filter_data, batch_size, output_channels, output_height, - output_width, input_channels, input_height, input_width, - filter_multiplier, filter_height, filter_width, stride_height, - stride_width, padding_height, padding_width, dilate_height, - dilate_width, output_data, data_layout); - } else { - if (c_filter == -1) - KernelDepthwiseConv( + int final_filter_multiplier = filter_multiplier; + int h_stride = stride_height; + int w_stride = stride_width; + if (c_filter_multiplier != 0) { + final_filter_multiplier = c_filter_multiplier; + h_stride = c_stride; + w_stride = c_stride; + } + if (c_filter == -1) { + if (data_layout == DataLayout::kNCHW) { + KernelDepthwiseConvNCHW( input_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, - c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, - padding_height, padding_width, dilate_height, dilate_width, + final_filter_multiplier, filter_height, filter_width, h_stride, + w_stride, padding_height, padding_width, dilate_height, dilate_width, output_data, data_layout); - else - KernelDepthwiseConvCFilter( + } else { + KernelDepthwiseConvNHWC( input_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, - c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, - padding_height, padding_width, dilate_height, dilate_width, + final_filter_multiplier, filter_height, filter_width, h_stride, + w_stride, padding_height, padding_width, dilate_height, dilate_width, output_data, data_layout); + } + } else { + KernelDepthwiseConvCFilter( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + final_filter_multiplier, filter_height, filter_width, h_stride, + w_stride, padding_height, padding_width, dilate_height, dilate_width, + output_data, data_layout); } } @@ -564,12 +604,22 @@ class DepthwiseConvFunctor<<>>( \