未验证 提交 976af0da 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of depthwise_conv (#46896)

Optimize performance of depthwise_conv

Config: input[2048, 1024, 4, 4], filter[1024, 1, 4, 4], stride=1, pad=0, dilation=1
上级 7eef05c2
...@@ -87,43 +87,36 @@ class DepthwiseConvFilterGradFunctor { ...@@ -87,43 +87,36 @@ class DepthwiseConvFilterGradFunctor {
const DataLayout data_layout = DataLayout::kNCHW); const DataLayout data_layout = DataLayout::kNCHW);
}; };
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
template <typename T> template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) { __forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
typedef cub::WarpReduce<T> WarpReduce; for (int mask = HALF_WARP; mask > 0; mask >>= 1)
typename WarpReduce::TempStorage temp_storage; val += platform::CudaShuffleDownSync(lane_mask, val, mask);
val = WarpReduce(temp_storage).Sum(val, warp_size);
return val; return val;
} }
template <typename T> template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val) { __forceinline__ __device__ T BlockReduceSum(T val, unsigned mask = FINAL_MASK) {
static __shared__ T shared[32]; static __shared__ T shared[WARP_SIZE];
int thread_id = threadIdx.x + threadIdx.y * blockDim.x + int tid = threadIdx.y * blockDim.x + threadIdx.x;
threadIdx.z * blockDim.x * blockDim.y; int lane = tid & 0x1f;
int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); int wid = tid >> 5;
int lane = thread_id % warp_size;
int wid = thread_id / warp_size; val = WarpReduceSum<T>(val, mask);
val = WarpReduceSum(val, warp_size); // Each warp performs partial reduction
if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
__syncthreads(); // Wait for all partial reductions
// read from shared memory only if that warp existed
int block_size = blockDim.x * blockDim.y * blockDim.z;
if (thread_id < (block_size - 1) / warp_size + 1) {
val = shared[lane];
} else {
val = static_cast<T>(0);
}
if (wid == 0) {
val = WarpReduceSum(val, warp_size); // Final reduce within first warp
}
__syncthreads(); __syncthreads();
if (thread_id != 0) { if (lane == 0) shared[wid] = val;
val = static_cast<T>(0);
} __syncthreads();
// align block_span to WARP_SIZE
int block_span = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5;
val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = WarpReduceSum<T>(val, mask);
return val; return val;
} }
...@@ -139,55 +132,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) { ...@@ -139,55 +132,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) {
// A CUDA kernel to compute the depthwise convolution forward pass // A CUDA kernel to compute the depthwise convolution forward pass
// in NCHW format. // in NCHW format.
template <typename T, bool fuse_relu_before_conv> template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW( __device__ __inline__ void KernelDepthwiseConvNCHW(
ARG_DEFINE_KernelDepthwiseConv) { ARG_DEFINE_KernelDepthwiseConv) {
const int fw_size = c_filter != -1 ? c_filter : filter_width;
const int fh_size = c_filter != -1 ? c_filter : filter_height;
int idx = threadIdx.x + blockIdx.x * blockDim.x; int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= (output_channels * batch_size * output_height * output_width)) if (idx >= (output_channels * batch_size * output_height * output_width))
return; return;
const int w_out = idx % output_width; int tmp_1 = idx / output_width;
const int h_out = (idx / output_width) % output_height; const int w_out = idx - tmp_1 * output_width;
const int c_out = (idx / output_width / output_height) % output_channels; int tmp_2 = tmp_1 / output_height;
const int batch = idx / output_width / output_height / output_channels; const int h_out = tmp_1 - tmp_2 * output_height;
tmp_1 = tmp_2;
tmp_2 = tmp_1 / output_channels;
const int c_out = tmp_1 - tmp_2 * output_channels;
const int batch = tmp_2;
const int c_in = c_out / filter_multiplier; const int c_in = c_out / filter_multiplier;
const T* weight = filter_data + c_out * filter_height * filter_width;
T value(0); T value(0);
const int h_in_start = -padding_height + h_out * stride_height;
const int w_in_start = -padding_width + w_out * stride_width;
const int h_in_end = h_in_start + filter_height * dilate_height;
const int w_in_end = w_in_start + filter_width * dilate_width;
int in_offset = int in_offset =
((batch * input_channels + c_in) * input_height) * input_width; ((batch * input_channels + c_in) * input_height) * input_width;
int weight_offset = c_out * filter_height * filter_width;
const int h_end = h_in_end < input_height ? h_in_end : input_height; int h_in_start = -padding_height + h_out * stride_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width; int w_in_start = -padding_width + w_out * stride_width;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int w_start = w_in_start > 0 ? w_in_start : 0;
int weight_offset = 0;
#pragma unroll #pragma unroll
for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { for (int fh = 0, h_in = h_in_start; fh < fh_size;
fh++, h_in += dilate_height) {
#pragma unroll #pragma unroll
for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { for (int fw = 0, w_in = w_in_start; fw < fw_size;
if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) { fw++, w_in += dilate_width) {
if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
int offset = in_offset + h_in * input_width + w_in; int offset = in_offset + h_in * input_width + w_in;
T in_data = input_data[offset]; T in_data = input_data[offset];
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += weight[weight_offset] * T(max(0.0f, double(in_data))); value += filter_data[weight_offset] *
static_cast<T>(max(0.0f, static_cast<double>(in_data)));
} else { } else {
value += weight[weight_offset] * in_data; value += filter_data[weight_offset] * in_data;
} }
} }
weight_offset++; weight_offset++;
} }
} }
int index = batch * output_channels * output_height * output_width + output_data[idx] = value;
c_out * output_height * output_width + h_out * output_width +
w_out;
output_data[index] = value;
} }
// A CUDA kernel to compute the depthwise convolution forward pass // A CUDA kernel to compute the depthwise convolution forward pass
...@@ -228,7 +219,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC( ...@@ -228,7 +219,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
T in_data = input_data[offset]; T in_data = input_data[offset];
const T* weight = filter_data + weight_offset * output_channels + c_out; const T* weight = filter_data + weight_offset * output_channels + c_out;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += weight[0] * T(max(0.0f, double(in_data))); value += weight[0] *
static_cast<T>(max(0.0f, static_cast<double>(in_data)));
} else { } else {
value += weight[0] * in_data; value += weight[0] * in_data;
} }
...@@ -281,7 +273,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW( ...@@ -281,7 +273,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
int offset = in_offset + h_in * input_width + w_in; int offset = in_offset + h_in * input_width + w_in;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] * value += r_weight[h_f * c_filter + w_f] *
T(max(0.0f, double(input_data[offset]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[offset])));
} else { } else {
value += r_weight[h_f * c_filter + w_f] * input_data[offset]; value += r_weight[h_f * c_filter + w_f] * input_data[offset];
} }
...@@ -337,7 +330,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC( ...@@ -337,7 +330,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
in_offset + (h_in * input_width + w_in) * input_channels + c_in; in_offset + (h_in * input_width + w_in) * input_channels + c_in;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] * value += r_weight[h_f * c_filter + w_f] *
T(max(0.0, double(input_data[offset]))); static_cast<T>(
max(0.0, static_cast<double>(input_data[offset])));
} else { } else {
value += r_weight[h_f * c_filter + w_f] * input_data[offset]; value += r_weight[h_f * c_filter + w_f] * input_data[offset];
} }
...@@ -367,7 +361,8 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -367,7 +361,8 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
} }
if (c_filter == -1) { if (c_filter == -1) {
if (data_layout != DataLayout::kNHWC) { if (data_layout != DataLayout::kNHWC) {
KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(input_data, KernelDepthwiseConvNCHW<T, c_filter, fuse_relu_before_conv>(
input_data,
filter_data, filter_data,
batch_size, batch_size,
output_channels, output_channels,
...@@ -467,60 +462,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -467,60 +462,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
const int dilate_height, const int dilate_width, \ const int dilate_height, const int dilate_width, \
T *const input_grad_data T *const input_grad_data
template <typename T, bool fuse_relu_before_conv> template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW( __device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
ARG_DEFINE_KernelDepthwiseConvInputGrad) { ARG_DEFINE_KernelDepthwiseConvInputGrad) {
const int batch = blockIdx.y; const int fw_size = c_filter != -1 ? c_filter : filter_width;
const int c_in = blockIdx.x; const int fh_size = c_filter != -1 ? c_filter : filter_height;
for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { if (idx >= batch_size * input_channels * input_height * input_width) {
const int c_out_start = c_in * filter_multiplier; return;
int h_out_start = }
h_in - (filter_height - 1) * dilate_height + padding_height;
int h_out_end = h_in + padding_height;
int w_out_start =
w_in - (filter_width - 1) * dilate_width + padding_width;
int w_out_end = w_in + padding_width;
T value(0);
int index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
if (input_data[index] <= T(0)) { if (input_data[idx] <= static_cast<T>(0.0f)) {
input_grad_data[index] = 0; input_grad_data[idx] = 0;
continue; return;
} }
} }
for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier; int tmp_1 = idx / input_width;
c_out++) { const int w_in = idx - tmp_1 * input_width;
int filter_offset = (c_out + 1) * filter_height * filter_width; int tmp_2 = tmp_1 / input_height;
for (int h_out = h_out_start; h_out <= h_out_end; const int h_in = tmp_1 - tmp_2 * input_height;
h_out += dilate_height) { tmp_1 = tmp_2;
for (int w_out = w_out_start; w_out <= w_out_end; tmp_2 = tmp_1 / input_channels;
w_out += dilate_width) { const int c_in = tmp_1 - tmp_2 * input_channels;
filter_offset--; const int batch = tmp_2;
int s_h_out = h_out / stride_height;
int s_w_out = w_out / stride_width; T value(0);
if (h_out % stride_height == 0 && w_out % stride_width == 0 && for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) {
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && int c_out = c_in * filter_multiplier + c_mul;
s_w_out < output_width) { int filter_offset = c_out * filter_height * filter_width;
#pragma unroll
for (int fh = 0; fh < fh_size; ++fh) {
#pragma unroll
for (int fw = 0; fw < fw_size; ++fw) {
int h_out = h_in + padding_height - fh * dilate_height;
int w_out = w_in + padding_width - fw * dilate_width;
if ((h_out - h_out / stride_height * stride_height == 0) &&
(w_out - w_out / stride_width * stride_width == 0)) {
h_out /= stride_height;
w_out /= stride_width;
if (h_out >= 0 && h_out < output_height && w_out >= 0 &&
w_out < output_width) {
int output_grad_offset = int output_grad_offset =
((batch * output_channels + c_out) * output_height + ((batch * output_channels + c_out) * output_height + h_out) *
s_h_out) *
output_width + output_width +
s_w_out; w_out;
value += output_grad_data[output_grad_offset] * value += output_grad_data[output_grad_offset] *
filter_data[filter_offset]; filter_data[filter_offset];
} }
} }
filter_offset++;
} }
} }
input_grad_data[index] = value;
}
} }
input_grad_data[idx] = value;
} }
template <typename T, bool fuse_relu_before_conv> template <typename T, bool fuse_relu_before_conv>
...@@ -733,7 +730,7 @@ __global__ void KernelDepthwiseConvInputGradSp( ...@@ -733,7 +730,7 @@ __global__ void KernelDepthwiseConvInputGradSp(
if (c_filter_multiplier == 0 || c_filter == -1) { if (c_filter_multiplier == 0 || c_filter == -1) {
if (data_layout != DataLayout::kNHWC) { if (data_layout != DataLayout::kNHWC) {
KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>( KernelDepthwiseConvInputGradNCHW<T, c_filter, fuse_relu_before_conv>(
input_data, input_data,
output_grad_data, output_grad_data,
filter_data, filter_data,
...@@ -854,44 +851,81 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW( ...@@ -854,44 +851,81 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
const int dilate_height, const int dilate_height,
const int dilate_width, const int dilate_width,
T* filter_grad_data) { T* filter_grad_data) {
T s(0); T f_grad(0);
int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; const bool loop_batch = output_height * output_width >= WARP_SIZE;
for (int image_w = threadIdx.x; image_w < output_width; int kw_id = blockIdx.x;
image_w += blockDim.x) { int kh_id = blockIdx.y;
for (int bid = 0; bid < num; bid++) { int oc_id = blockIdx.z;
for (int image_h = threadIdx.y; image_h < output_height; int ic_id = oc_id / filter_multiplier;
image_h += blockDim.y) { int idx = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
int kernel_id = blockIdx.z;
int kernel_h = blockIdx.y * dilate_height - padding_height; const int ohw = output_height * output_width;
int kernel_w = blockIdx.x * dilate_width - padding_width; const int onhw = num * ohw;
const int h_offset = kh_id * dilate_height - padding_height;
int image_hk = image_h * stride_height + kernel_h; const int w_offset = kw_id * dilate_width - padding_width;
int image_wk = image_w * stride_width + kernel_w;
if (image_hk < 0 || image_hk >= input_height) continue; if (loop_batch) {
if (image_wk < 0 || image_wk >= input_width) continue; for (int og_w = threadIdx.x; og_w < output_width; og_w += blockDim.x) {
#define gaid(N, C, H, W) \ for (int bid = 0; bid < num; ++bid) {
((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W)) for (int og_h = threadIdx.y; og_h < output_height; og_h += blockDim.y) {
int input_id = ((bid * (gridDim.z / filter_multiplier) + int i_h = og_h * stride_height + h_offset;
kernel_id / filter_multiplier) * int i_w = og_w * stride_width + w_offset;
input_height +
image_hk) * if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) {
int input_offset =
((bid * input_channels + ic_id) * input_height + i_h) *
input_width +
i_w;
int output_grad_offset =
((bid * output_channels + oc_id) * output_height + og_h) *
output_width +
og_w;
if (fuse_relu_before_conv) {
f_grad +=
output_grad_data[output_grad_offset] *
static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_offset])));
} else {
f_grad += output_grad_data[output_grad_offset] *
input_data[input_offset];
}
}
}
}
}
} else {
for (int id = threadIdx.x; id < onhw; id += blockDim.x) {
int bid = id / ohw;
int og_hw = id - bid * ohw;
int og_h = og_hw / output_width;
int og_w = og_hw - og_h * output_width;
int i_h = og_h * stride_height + h_offset;
int i_w = og_w * stride_width + w_offset;
if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) {
int input_offset =
((bid * input_channels + ic_id) * input_height + i_h) *
input_width + input_width +
image_wk; i_w;
int output_grad_offset = (bid * output_channels + oc_id) * ohw + og_hw;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * f_grad += output_grad_data[output_grad_offset] *
T(max(0.0f, double(input_data[input_id]))); static_cast<T>(max(
0.0f, static_cast<double>(input_data[input_offset])));
} else { } else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * f_grad +=
input_data[input_id]; output_grad_data[output_grad_offset] * input_data[input_offset];
} }
#undef gaid
} }
} }
} }
T val = BlockReduceSum(s); T val = BlockReduceSum<T>(f_grad);
platform::CudaAtomicAdd(&filter_grad_data[gbid], val); if (threadIdx.x == 0 && threadIdx.y == 0) {
filter_grad_data[idx] = val;
}
} }
template <typename T, bool fuse_relu_before_conv> template <typename T, bool fuse_relu_before_conv>
...@@ -941,7 +975,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC( ...@@ -941,7 +975,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
kernel_id / filter_multiplier; kernel_id / filter_multiplier;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] * s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
T(max(0.0f, double(input_data[input_id]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
} else { } else {
s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] * s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
input_data[input_id]; input_data[input_id];
...@@ -1013,7 +1048,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC( ...@@ -1013,7 +1048,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
T s(0); T s(0);
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
s = output_grad_data[output_id] * s = output_grad_data[output_id] *
T(max(0.0f, double(input_data[input_id]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
} else { } else {
s = output_grad_data[output_id] * input_data[input_id]; s = output_grad_data[output_id] * input_data[input_id];
} }
...@@ -1242,8 +1278,7 @@ class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> { ...@@ -1242,8 +1278,7 @@ class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
batch_size); batch_size);
} }
int filter_multiplier = output_channels / input_channels; int filter_multiplier = output_channels / input_channels;
int nums_output = int nums_output = output->numel();
batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__ #ifdef __HIPCC__
int block_size = 256; int block_size = 256;
#else #else
...@@ -1416,6 +1451,13 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> { ...@@ -1416,6 +1451,13 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
batch_size); batch_size);
} }
int filter_multiplier = output_channels / input_channels; int filter_multiplier = output_channels / input_channels;
int nums_input = input_grad->numel();
#ifdef __HIPCC__
int block_size = 256;
#else
int block_size = 512;
#endif
int grid_size = (nums_input + block_size - 1) / block_size;
#define check_case(c_filter_multiplier, c_stride, c_filter) \ #define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \ if (c_filter_multiplier == 0 || \
...@@ -1424,6 +1466,11 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> { ...@@ -1424,6 +1466,11 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
(ksize_height == ksize_width && ksize_height == c_filter || \ (ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \ c_filter == -1)) { \
if (data_layout != DataLayout::kNHWC) { \ if (data_layout != DataLayout::kNHWC) { \
if (c_filter == -1) { \
threads.x = block_size; \
grid.x = grid_size; \
threads.y = threads.z = grid.y = grid.z = 1; \
} \
KernelDepthwiseConvInputGradSp<T, \ KernelDepthwiseConvInputGradSp<T, \
c_filter_multiplier, \ c_filter_multiplier, \
c_stride, \ c_stride, \
...@@ -1554,6 +1601,10 @@ class DepthwiseConvFilterGradFunctor<phi::GPUContext, ...@@ -1554,6 +1601,10 @@ class DepthwiseConvFilterGradFunctor<phi::GPUContext,
blocks = std::min(std::max(block_size / output_width, 1), output_height); blocks = std::min(std::max(block_size / output_width, 1), output_height);
grid = dim3(ksize_width, ksize_height, output_channels); grid = dim3(ksize_width, ksize_height, output_channels);
threads = dim3(std::min(output_width, block_size), blocks, 1); threads = dim3(std::min(output_width, block_size), blocks, 1);
if (output_height * output_width < WARP_SIZE) {
threads = dim3(
std::min(block_size, batch_size * output_height * output_width));
}
} else { } else {
blocks = std::min( blocks = std::min(
std::max(block_size / output_channels, 1), std::max(block_size / output_channels, 1),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册