未验证 提交 f17a73e9 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of depthwise_conv_bwd (#46362)

* Optimize performance of depthwise_conv_bwd

* fix
上级 2e231402
...@@ -469,60 +469,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -469,60 +469,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
const int dilate_height, const int dilate_width, \ const int dilate_height, const int dilate_width, \
T *const input_grad_data T *const input_grad_data
template <typename T, bool fuse_relu_before_conv> template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradNCHW( __device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
ARG_DEFINE_KernelDepthwiseConvInputGrad) { ARG_DEFINE_KernelDepthwiseConvInputGrad) {
const int batch = blockIdx.y; const int fw_size = c_filter != -1 ? c_filter : filter_width;
const int c_in = blockIdx.x; const int fh_size = c_filter != -1 ? c_filter : filter_height;
for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { if (idx >= batch_size * input_channels * input_height * input_width) {
const int c_out_start = c_in * filter_multiplier; return;
int h_out_start = }
h_in - (filter_height - 1) * dilate_height + padding_height;
int h_out_end = h_in + padding_height;
int w_out_start =
w_in - (filter_width - 1) * dilate_width + padding_width;
int w_out_end = w_in + padding_width;
T value(0);
int index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
if (input_data[index] <= T(0)) { if (input_data[idx] <= static_cast<T>(0.0f)) {
input_grad_data[index] = 0; input_grad_data[idx] = 0;
continue; return;
} }
} }
for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier; int tmp_1 = idx / input_width;
c_out++) { const int w_in = idx - tmp_1 * input_width;
int filter_offset = (c_out + 1) * filter_height * filter_width; int tmp_2 = tmp_1 / input_height;
for (int h_out = h_out_start; h_out <= h_out_end; const int h_in = tmp_1 - tmp_2 * input_height;
h_out += dilate_height) { tmp_1 = tmp_2;
for (int w_out = w_out_start; w_out <= w_out_end; tmp_2 = tmp_1 / input_channels;
w_out += dilate_width) { const int c_in = tmp_1 - tmp_2 * input_channels;
filter_offset--; const int batch = tmp_2;
int s_h_out = h_out / stride_height;
int s_w_out = w_out / stride_width; T value(0);
if (h_out % stride_height == 0 && w_out % stride_width == 0 && for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) {
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && int c_out = c_in * filter_multiplier + c_mul;
s_w_out < output_width) { int filter_offset = c_out * filter_height * filter_width;
#pragma unroll
for (int fh = 0; fh < fh_size; ++fh) {
#pragma unroll
for (int fw = 0; fw < fw_size; ++fw) {
int h_out = h_in + padding_height - fh * dilate_height;
int w_out = w_in + padding_width - fw * dilate_width;
if ((h_out - h_out / stride_height * stride_height == 0) &&
(w_out - w_out / stride_width * stride_width == 0)) {
h_out /= stride_height;
w_out /= stride_width;
if (h_out >= 0 && h_out < output_height && w_out >= 0 &&
w_out < output_width) {
int output_grad_offset = int output_grad_offset =
((batch * output_channels + c_out) * output_height + ((batch * output_channels + c_out) * output_height + h_out) *
s_h_out) *
output_width + output_width +
s_w_out; w_out;
value += output_grad_data[output_grad_offset] * value += output_grad_data[output_grad_offset] *
filter_data[filter_offset]; filter_data[filter_offset];
} }
} }
filter_offset++;
} }
} }
input_grad_data[index] = value;
}
} }
input_grad_data[idx] = value;
} }
template <typename T, bool fuse_relu_before_conv> template <typename T, bool fuse_relu_before_conv>
...@@ -735,7 +737,7 @@ __global__ void KernelDepthwiseConvInputGradSp( ...@@ -735,7 +737,7 @@ __global__ void KernelDepthwiseConvInputGradSp(
if (c_filter_multiplier == 0 || c_filter == -1) { if (c_filter_multiplier == 0 || c_filter == -1) {
if (data_layout != DataLayout::kNHWC) { if (data_layout != DataLayout::kNHWC) {
KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>( KernelDepthwiseConvInputGradNCHW<T, c_filter, fuse_relu_before_conv>(
input_data, input_data,
output_grad_data, output_grad_data,
filter_data, filter_data,
...@@ -1247,8 +1249,7 @@ class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> { ...@@ -1247,8 +1249,7 @@ class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
batch_size); batch_size);
} }
int filter_multiplier = output_channels / input_channels; int filter_multiplier = output_channels / input_channels;
int nums_output = int nums_output = output->numel();
batch_size * output_channels * output_height * output_width;
#ifdef __HIPCC__ #ifdef __HIPCC__
int block_size = 256; int block_size = 256;
#else #else
...@@ -1421,6 +1422,13 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> { ...@@ -1421,6 +1422,13 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
batch_size); batch_size);
} }
int filter_multiplier = output_channels / input_channels; int filter_multiplier = output_channels / input_channels;
int nums_input = input_grad->numel();
#ifdef __HIPCC__
int block_size = 256;
#else
int block_size = 512;
#endif
int grid_size = (nums_input + block_size - 1) / block_size;
#define check_case(c_filter_multiplier, c_stride, c_filter) \ #define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \ if (c_filter_multiplier == 0 || \
...@@ -1429,6 +1437,11 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> { ...@@ -1429,6 +1437,11 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
(ksize_height == ksize_width && ksize_height == c_filter || \ (ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \ c_filter == -1)) { \
if (data_layout != DataLayout::kNHWC) { \ if (data_layout != DataLayout::kNHWC) { \
if (c_filter == -1) { \
threads.x = block_size; \
grid.x = grid_size; \
threads.y = threads.z = grid.y = grid.z = 1; \
} \
KernelDepthwiseConvInputGradSp<T, \ KernelDepthwiseConvInputGradSp<T, \
c_filter_multiplier, \ c_filter_multiplier, \
c_stride, \ c_stride, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册