未验证 提交 330b1a0a 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of depthwise_conv_fwd (#46287)

* Optimize performance of depthwise_conv_fwd

* fix
上级 22fe4f03
...@@ -139,56 +139,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) { ...@@ -139,56 +139,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) {
// A Cuda kernel to compute the depthwise convolution forward pass // A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format. // in NCHW format.
template <typename T, bool fuse_relu_before_conv> template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNCHW( __device__ __inline__ void KernelDepthwiseConvNCHW(
ARG_DEFINE_KernelDepthwiseConv) { ARG_DEFINE_KernelDepthwiseConv) {
const int fw_size = c_filter != -1 ? c_filter : filter_width;
const int fh_size = c_filter != -1 ? c_filter : filter_height;
int idx = threadIdx.x + blockIdx.x * blockDim.x; int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= (output_channels * batch_size * output_height * output_width)) if (idx >= (output_channels * batch_size * output_height * output_width))
return; return;
const int w_out = idx % output_width; int tmp_1 = idx / output_width;
const int h_out = (idx / output_width) % output_height; const int w_out = idx - tmp_1 * output_width;
const int c_out = (idx / output_width / output_height) % output_channels; int tmp_2 = tmp_1 / output_height;
const int batch = idx / output_width / output_height / output_channels; const int h_out = tmp_1 - tmp_2 * output_height;
tmp_1 = tmp_2;
tmp_2 = tmp_1 / output_channels;
const int c_out = tmp_1 - tmp_2 * output_channels;
const int batch = tmp_2;
const int c_in = c_out / filter_multiplier; const int c_in = c_out / filter_multiplier;
const T* weight = filter_data + c_out * filter_height * filter_width;
T value(0); T value(0);
const int h_in_start = -padding_height + h_out * stride_height;
const int w_in_start = -padding_width + w_out * stride_width;
const int h_in_end = h_in_start + filter_height * dilate_height;
const int w_in_end = w_in_start + filter_width * dilate_width;
int in_offset = int in_offset =
((batch * input_channels + c_in) * input_height) * input_width; ((batch * input_channels + c_in) * input_height) * input_width;
int weight_offset = c_out * filter_height * filter_width;
const int h_end = h_in_end < input_height ? h_in_end : input_height; int h_in_start = -padding_height + h_out * stride_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width; int w_in_start = -padding_width + w_out * stride_width;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int w_start = w_in_start > 0 ? w_in_start : 0;
int weight_offset = 0;
#pragma unroll #pragma unroll
for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { for (int fh = 0, h_in = h_in_start; fh < fh_size;
fh++, h_in += dilate_height) {
#pragma unroll #pragma unroll
for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { for (int fw = 0, w_in = w_in_start; fw < fw_size;
if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) { fw++, w_in += dilate_width) {
if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
int offset = in_offset + h_in * input_width + w_in; int offset = in_offset + h_in * input_width + w_in;
T in_data = input_data[offset]; T in_data = input_data[offset];
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += weight[weight_offset] * value += filter_data[weight_offset] *
T(max(0.0f, static_cast<double>(in_data))); static_cast<T>(max(0.0f, static_cast<double>(in_data)));
} else { } else {
value += weight[weight_offset] * in_data; value += filter_data[weight_offset] * in_data;
} }
} }
weight_offset++; weight_offset++;
} }
} }
int index = batch * output_channels * output_height * output_width + output_data[idx] = value;
c_out * output_height * output_width + h_out * output_width +
w_out;
output_data[index] = value;
} }
// A Cuda kernel to compute the depthwise convolution forward pass // A Cuda kernel to compute the depthwise convolution forward pass
...@@ -229,7 +226,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC( ...@@ -229,7 +226,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
T in_data = input_data[offset]; T in_data = input_data[offset];
const T* weight = filter_data + weight_offset * output_channels + c_out; const T* weight = filter_data + weight_offset * output_channels + c_out;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += weight[0] * T(max(0.0f, static_cast<double>(in_data))); value += weight[0] *
static_cast<T>(max(0.0f, static_cast<double>(in_data)));
} else { } else {
value += weight[0] * in_data; value += weight[0] * in_data;
} }
...@@ -282,7 +280,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW( ...@@ -282,7 +280,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
int offset = in_offset + h_in * input_width + w_in; int offset = in_offset + h_in * input_width + w_in;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] * value += r_weight[h_f * c_filter + w_f] *
T(max(0.0f, static_cast<double>(input_data[offset]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[offset])));
} else { } else {
value += r_weight[h_f * c_filter + w_f] * input_data[offset]; value += r_weight[h_f * c_filter + w_f] * input_data[offset];
} }
...@@ -338,7 +337,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC( ...@@ -338,7 +337,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
in_offset + (h_in * input_width + w_in) * input_channels + c_in; in_offset + (h_in * input_width + w_in) * input_channels + c_in;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] * value += r_weight[h_f * c_filter + w_f] *
T(max(0.0, static_cast<double>(input_data[offset]))); static_cast<T>(
max(0.0, static_cast<double>(input_data[offset])));
} else { } else {
value += r_weight[h_f * c_filter + w_f] * input_data[offset]; value += r_weight[h_f * c_filter + w_f] * input_data[offset];
} }
...@@ -368,25 +368,26 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -368,25 +368,26 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
} }
if (c_filter == -1) { if (c_filter == -1) {
if (data_layout != DataLayout::kNHWC) { if (data_layout != DataLayout::kNHWC) {
KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(input_data, KernelDepthwiseConvNCHW<T, c_filter, fuse_relu_before_conv>(
filter_data, input_data,
batch_size, filter_data,
output_channels, batch_size,
output_height, output_channels,
output_width, output_height,
input_channels, output_width,
input_height, input_channels,
input_width, input_height,
final_filter_multiplier, input_width,
filter_height, final_filter_multiplier,
filter_width, filter_height,
h_stride, filter_width,
w_stride, h_stride,
padding_height, w_stride,
padding_width, padding_height,
dilate_height, padding_width,
dilate_width, dilate_height,
output_data); dilate_width,
output_data);
} else { } else {
KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(input_data, KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(input_data,
filter_data, filter_data,
...@@ -881,7 +882,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW( ...@@ -881,7 +882,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
image_wk; image_wk;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
T(max(0.0f, static_cast<double>(input_data[input_id]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
} else { } else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[input_id]; input_data[input_id];
...@@ -942,7 +944,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC( ...@@ -942,7 +944,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
kernel_id / filter_multiplier; kernel_id / filter_multiplier;
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] * s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
T(max(0.0f, static_cast<double>(input_data[input_id]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
} else { } else {
s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] * s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
input_data[input_id]; input_data[input_id];
...@@ -1014,7 +1017,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC( ...@@ -1014,7 +1017,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
T s(0); T s(0);
if (fuse_relu_before_conv) { if (fuse_relu_before_conv) {
s = output_grad_data[output_id] * s = output_grad_data[output_id] *
T(max(0.0f, static_cast<double>(input_data[input_id]))); static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
} else { } else {
s = output_grad_data[output_id] * input_data[input_id]; s = output_grad_data[output_id] * input_data[input_id];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册