未验证 提交 0d8b222b 编写于 作者: W wangchaochaohu 提交者: GitHub

Optimize the depthwise op test=develop (#22265)

上级 325f0722
...@@ -45,67 +45,106 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { ...@@ -45,67 +45,106 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
// Depthwise convolution forward pass for tensors stored in NCHW order.
//
// Work mapping: every thread produces exactly one output element. The flat
// element id is built from threadIdx.x / blockIdx.x / blockDim.x only (no y/z
// launch dimensions are read), and threads whose id is past
// batch_size * output_channels * output_height * output_width exit early.
//
// The filter taps for output channel c_out are read contiguously starting at
// filter_data + c_out * filter_height * filter_width. The matching input
// channel is c_out / filter_multiplier.
//
// When fuse_relu_before_conv is true, each input sample is clamped with
// max(0.0f, x) before it is multiplied by its filter tap.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= (output_channels * batch_size * output_height * output_width)) {
    return;
  }

  // Peel the flat id apart, innermost (fastest varying) dimension first:
  // width, then height, then channel, then batch.
  int remaining = tid;
  const int w_out = remaining % output_width;
  remaining /= output_width;
  const int h_out = remaining % output_height;
  remaining /= output_height;
  const int c_out = remaining % output_channels;
  const int batch = remaining / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* tap = filter_data + c_out * filter_height * filter_width;

  // Receptive field of this output element in input coordinates. The start
  // may be negative and the end may exceed the input size because of padding.
  const int h_in_start = h_out * stride_height - padding_height;
  const int w_in_start = w_out * stride_width - padding_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  // Receptive field clipped to the valid input area.
  const int h_lo = h_in_start > 0 ? h_in_start : 0;
  const int w_lo = w_in_start > 0 ? w_in_start : 0;
  const int h_hi = h_in_end < input_height ? h_in_end : input_height;
  const int w_hi = w_in_end < input_width ? w_in_end : input_width;

  // Offset of the first element of the (batch, c_in) input plane.
  const int plane_offset =
      ((batch * input_channels + c_in) * input_height) * input_width;

  T acc = 0;
#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
    const bool row_inside = (h_in >= h_lo) && (h_in < h_hi);
    const int row_offset = plane_offset + h_in * input_width;
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (row_inside && w_in >= w_lo && w_in < w_hi) {
        const T in_data = input_data[row_offset + w_in];
        if (fuse_relu_before_conv) {
          acc += (*tap) * max(0.0f, in_data);
        } else {
          acc += (*tap) * in_data;
        }
      }
      // The tap pointer advances for every filter position, including the
      // ones that fall into the padding region.
      ++tap;
    }
  }

  const int out_index =
      ((batch * output_channels + c_out) * output_height + h_out) *
          output_width +
      w_out;
  output_data[out_index] = acc;
}
// A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format.
//
// Work mapping: every thread produces exactly one output element. The flat
// element id is threadIdx.x + blockIdx.x * blockDim.x (only the x launch
// dimension is read); threads whose id is past
// batch_size * output_channels * output_height * output_width exit early.
// The id is decomposed with the channel as the fastest-varying dimension.
//
// The filter taps for output channel c_out are read contiguously starting at
// filter_data + c_out * filter_height * filter_width, and the matching input
// channel is c_out / filter_multiplier.
//
// When fuse_relu_before_conv is true, each input sample is clamped with
// max(0.0f, x) before it is multiplied by its filter tap.
template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
    ARG_DEFINE_KernelDepthwiseConv) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= (output_channels * batch_size * output_height * output_width))
    return;

  const int c_out = idx % output_channels;
  const int w_out = (idx / output_channels) % output_width;
  const int h_out = (idx / output_channels / output_width) % output_height;
  const int batch = idx / output_width / output_height / output_channels;

  const int c_in = c_out / filter_multiplier;
  const T* weight = filter_data + c_out * filter_height * filter_width;
  T value = 0;

  // Receptive field in input coordinates; may extend into the padding.
  const int h_in_start = -padding_height + h_out * stride_height;
  const int w_in_start = -padding_width + w_out * stride_width;
  const int h_in_end = h_in_start + filter_height * dilate_height;
  const int w_in_end = w_in_start + filter_width * dilate_width;

  // Receptive field clipped to the valid input area.
  const int h_end = h_in_end < input_height ? h_in_end : input_height;
  const int w_end = w_in_end < input_width ? w_in_end : input_width;
  const int h_start = h_in_start > 0 ? h_in_start : 0;
  const int w_start = w_in_start > 0 ? w_in_start : 0;

  int weight_offset = 0;
#pragma unroll
  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
#pragma unroll
    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
        // The input tensor holds input_channels values per pixel, so the
        // per-pixel stride must be input_channels. Using output_channels
        // here addresses the wrong element (and can run past the buffer)
        // whenever filter_multiplier > 1.
        int offset = ((batch * input_height + h_in) * input_width + w_in) *
                         input_channels +
                     c_in;
        T in_data = input_data[offset];
        if (fuse_relu_before_conv) {
          value += weight[weight_offset] * max(0.0f, in_data);
        } else {
          value += weight[weight_offset] * in_data;
        }
      }
      // Advance for every filter position, including padded ones.
      weight_offset++;
    }
  }

  // Output is NHWC as well: channel is the innermost dimension.
  int index = batch * output_channels * output_height * output_width +
              h_out * output_width * output_channels + w_out * output_channels +
              c_out;
  output_data[index] = value;
}
template <typename T, int c_filter, bool fuse_relu_before_conv> template <typename T, int c_filter, bool fuse_relu_before_conv>
...@@ -183,36 +222,37 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( ...@@ -183,36 +222,37 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
// Kernel entry point that forwards to one of the depthwise convolution
// device routines.
//
// Template parameters select compile-time specializations:
//   * c_filter_multiplier != 0: the template values c_filter_multiplier and
//     c_stride replace the runtime filter_multiplier and both strides.
//   * c_filter == -1: the layout-specific routine is used
//     (KernelDepthwiseConvNCHW when data_layout is DataLayout::kNCHW,
//     KernelDepthwiseConvNHWC otherwise); any other value forwards to
//     KernelDepthwiseConvCFilter<T, c_filter, ...>.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  // Prefer the compile-time values when this instantiation carries them.
  const bool use_template_values = (c_filter_multiplier != 0);
  const int multiplier =
      use_template_values ? c_filter_multiplier : filter_multiplier;
  const int stride_h = use_template_values ? c_stride : stride_height;
  const int stride_w = use_template_values ? c_stride : stride_width;

  if (c_filter != -1) {
    // Filter size is fixed at compile time.
    KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width, multiplier,
        filter_height, filter_width, stride_h, stride_w, padding_height,
        padding_width, dilate_height, dilate_width, output_data, data_layout);
    return;
  }

  // Runtime filter size: pick the routine matching the tensor layout.
  if (data_layout == DataLayout::kNCHW) {
    KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width, multiplier,
        filter_height, filter_width, stride_h, stride_w, padding_height,
        padding_width, dilate_height, dilate_width, output_data, data_layout);
  } else {
    KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(
        input_data, filter_data, batch_size, output_channels, output_height,
        output_width, input_channels, input_height, input_width, multiplier,
        filter_height, filter_width, stride_h, stride_w, padding_height,
        padding_width, dilate_height, dilate_width, output_data, data_layout);
  }
}
...@@ -564,12 +604,22 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T, ...@@ -564,12 +604,22 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
dim3 threads(std::min(output_width, thread), blocks, 1); dim3 threads(std::min(output_width, thread), blocks, 1);
dim3 grid(output_channels, batch_size, 1); dim3 grid(output_channels, batch_size, 1);
int filter_multiplier = output_channels / input_channels; int filter_multiplier = output_channels / input_channels;
int nums_output =
batch_size * output_channels * output_height * output_width;
int block_size = 512;
#define check_case(c_filter_multiplier, c_stride, c_filter) \ #define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \ if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \ filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride && \ stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \ (ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \ c_filter == -1)) { \
if (c_filter == -1) { \
threads.x = block_size; \
grid.x = (nums_output + block_size - 1) / block_size; \
threads.y = threads.z = grid.y = grid.z = 1; \
} \
KernelDepthwiseConvSp< \ KernelDepthwiseConvSp< \
T, c_filter_multiplier, c_stride, c_filter, \ T, c_filter_multiplier, c_stride, c_filter, \
fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \ fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册