未验证 提交 0d8b222b 编写于 作者: W wangchaochaohu 提交者: GitHub

Optimize the depthwise op test=develop (#22265)

上级 325f0722
...@@ -45,11 +45,16 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { ...@@ -45,11 +45,16 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
// A Cuda kernel to compute the depthwise convolution forward pass // A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format. // in NCHW format.
template <typename T, bool fuse_relu_before_conv> template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { __device__ __inline__ void KernelDepthwiseConvNCHW(
for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { ARG_DEFINE_KernelDepthwiseConv) {
for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { int idx = threadIdx.x + blockIdx.x * blockDim.x;
const int batch = blockIdx.y; if (idx >= (output_channels * batch_size * output_height * output_width))
const int c_out = blockIdx.x; return;
const int w_out = idx % output_width;
const int h_out = (idx / output_width) % output_height;
const int c_out = (idx / output_width / output_height) % output_channels;
const int batch = idx / output_width / output_height / output_channels;
const int c_in = c_out / filter_multiplier; const int c_in = c_out / filter_multiplier;
const T* weight = filter_data + c_out * filter_height * filter_width; const T* weight = filter_data + c_out * filter_height * filter_width;
...@@ -59,13 +64,8 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -59,13 +64,8 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
const int h_in_end = h_in_start + filter_height * dilate_height; const int h_in_end = h_in_start + filter_height * dilate_height;
const int w_in_end = w_in_start + filter_width * dilate_width; const int w_in_end = w_in_start + filter_width * dilate_width;
int in_offset; int in_offset =
if (data_layout != DataLayout::kNHWC) {
in_offset =
((batch * input_channels + c_in) * input_height) * input_width; ((batch * input_channels + c_in) * input_height) * input_width;
} else {
in_offset = batch * input_height * input_width * input_channels;
}
const int h_end = h_in_end < input_height ? h_in_end : input_height; const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width; const int w_end = w_in_end < input_width ? w_in_end : input_width;
...@@ -73,39 +73,78 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -73,39 +73,78 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
const int w_start = w_in_start > 0 ? w_in_start : 0; const int w_start = w_in_start > 0 ? w_in_start : 0;
int weight_offset = 0; int weight_offset = 0;
#pragma unroll
for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
if (h_in >= h_start && h_in < h_end && w_in >= w_start && if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
w_in < w_end) { int offset = in_offset + h_in * input_width + w_in;
int offset; T in_data = input_data[offset];
if (data_layout != DataLayout::kNHWC) {
offset = in_offset + h_in * input_width + w_in;
} else {
offset = in_offset +
(h_in * input_width + w_in) * input_channels + c_in;
}
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += weight[weight_offset] * max(0.0f, input_data[offset]); value += weight[weight_offset] * max(0.0f, in_data);
} else { } else {
value += weight[weight_offset] * input_data[offset]; value += weight[weight_offset] * in_data;
} }
} }
weight_offset++; weight_offset++;
} }
} }
int index; int index = batch * output_channels * output_height * output_width +
if (data_layout != DataLayout::kNHWC) { c_out * output_height * output_width + h_out * output_width +
index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width +
w_out; w_out;
output_data[index] = value;
}
// A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
ARG_DEFINE_KernelDepthwiseConv) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= (output_channels * batch_size * output_height * output_width))
return;
const int c_out = idx % output_channels;
const int w_out = (idx / output_channels) % output_width;
const int h_out = (idx / output_channels / output_width) % output_height;
const int batch = idx / output_width / output_height / output_channels;
const int c_in = c_out / filter_multiplier;
const T* weight = filter_data + c_out * filter_height * filter_width;
T value = 0;
const int h_in_start = -padding_height + h_out * stride_height;
const int w_in_start = -padding_width + w_out * stride_width;
const int h_in_end = h_in_start + filter_height * dilate_height;
const int w_in_end = w_in_start + filter_width * dilate_width;
const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int w_start = w_in_start > 0 ? w_in_start : 0;
int weight_offset = 0;
#pragma unroll
for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
int offset = ((batch * input_height + h_in) * input_width + w_in) *
output_channels +
c_in;
T in_data = input_data[offset];
if (fuse_relu_before_conv) {
value += weight[weight_offset] * max(0.0f, in_data);
} else { } else {
index = ((batch * output_height + h_out) * output_width + w_out) * value += weight[weight_offset] * in_data;
gridDim.x +
c_out;
} }
output_data[index] = value; }
weight_offset++;
} }
} }
int index = batch * output_channels * output_height * output_width +
h_out * output_width * output_channels + w_out * output_channels +
c_out;
output_data[index] = value;
} }
template <typename T, int c_filter, bool fuse_relu_before_conv> template <typename T, int c_filter, bool fuse_relu_before_conv>
...@@ -183,35 +222,36 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( ...@@ -183,35 +222,36 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
template <typename T, int c_filter_multiplier, int c_stride, int c_filter, template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
bool fuse_relu_before_conv> bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
if (c_filter_multiplier == 0) { int final_filter_multiplier = filter_multiplier;
if (c_filter == -1) int h_stride = stride_height;
KernelDepthwiseConv<T, fuse_relu_before_conv>( int w_stride = stride_width;
if (c_filter_multiplier != 0) {
final_filter_multiplier = c_filter_multiplier;
h_stride = c_stride;
w_stride = c_stride;
}
if (c_filter == -1) {
if (data_layout == DataLayout::kNCHW) {
KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
input_data, filter_data, batch_size, output_channels, output_height, input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width, output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height, final_filter_multiplier, filter_height, filter_width, h_stride,
stride_width, padding_height, padding_width, dilate_height, w_stride, padding_height, padding_width, dilate_height, dilate_width,
dilate_width, output_data, data_layout); output_data, data_layout);
else
KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, output_data, data_layout);
} else { } else {
if (c_filter == -1) KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
KernelDepthwiseConv<T, fuse_relu_before_conv>(
input_data, filter_data, batch_size, output_channels, output_height, input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, final_filter_multiplier, filter_height, filter_width, h_stride,
padding_height, padding_width, dilate_height, dilate_width, w_stride, padding_height, padding_width, dilate_height, dilate_width,
output_data, data_layout); output_data, data_layout);
else }
} else {
KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>( KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
input_data, filter_data, batch_size, output_channels, output_height, input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, final_filter_multiplier, filter_height, filter_width, h_stride,
padding_height, padding_width, dilate_height, dilate_width, w_stride, padding_height, padding_width, dilate_height, dilate_width,
output_data, data_layout); output_data, data_layout);
} }
} }
...@@ -564,12 +604,22 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T, ...@@ -564,12 +604,22 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
dim3 threads(std::min(output_width, thread), blocks, 1); dim3 threads(std::min(output_width, thread), blocks, 1);
dim3 grid(output_channels, batch_size, 1); dim3 grid(output_channels, batch_size, 1);
int filter_multiplier = output_channels / input_channels; int filter_multiplier = output_channels / input_channels;
int nums_output =
batch_size * output_channels * output_height * output_width;
int block_size = 512;
#define check_case(c_filter_multiplier, c_stride, c_filter) \ #define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \ if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \ filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride && \ stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \ (ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \ c_filter == -1)) { \
if (c_filter == -1) { \
threads.x = block_size; \
grid.x = (nums_output + block_size - 1) / block_size; \
threads.y = threads.z = grid.y = grid.z = 1; \
} \
KernelDepthwiseConvSp< \ KernelDepthwiseConvSp< \
T, c_filter_multiplier, c_stride, c_filter, \ T, c_filter_multiplier, c_stride, c_filter, \
fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \ fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册