Unverified commit 330b1a0a authored by Zhang Zheng, committed by GitHub

Optimize performance of depthwise_conv_fwd (#46287)

* Optimize performance of depthwise_conv_fwd

* fix
Parent 22fe4f03
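The change is twofold: the flat thread index in KernelDepthwiseConvNCHW is now decomposed with reused quotients and multiply-subtracts instead of repeated `/` and `%`, and the compile-time filter size `c_filter` is threaded into the kernel so its loops get constant trip counts. The sketch below isolates the first idea and checks that both forms agree; it is illustrative only, and the helper names DecomposeNaive/DecomposeFast are not part of Paddle.

// index_decompose_sketch.cu -- illustrative only; DecomposeNaive and
// DecomposeFast are hypothetical helpers, not Paddle code.
#include <cassert>
#include <cstdio>

struct NCHWIndex { int batch, c_out, h_out, w_out; };

// Pre-patch pattern: repeated division plus three modulo operations.
__host__ __device__ NCHWIndex DecomposeNaive(int idx, int C, int H, int W) {
  NCHWIndex r;
  r.w_out = idx % W;
  r.h_out = (idx / W) % H;
  r.c_out = (idx / W / H) % C;
  r.batch = idx / W / H / C;
  return r;
}

// Post-patch pattern: three divisions; each remainder is recovered from the
// quotient with a multiply-subtract, avoiding the extra '%' instructions.
__host__ __device__ NCHWIndex DecomposeFast(int idx, int C, int H, int W) {
  NCHWIndex r;
  int tmp_1 = idx / W;
  r.w_out = idx - tmp_1 * W;
  int tmp_2 = tmp_1 / H;
  r.h_out = tmp_1 - tmp_2 * H;
  tmp_1 = tmp_2;
  tmp_2 = tmp_1 / C;
  r.c_out = tmp_1 - tmp_2 * C;
  r.batch = tmp_2;
  return r;
}

int main() {
  const int N = 2, C = 3, H = 5, W = 7;
  for (int idx = 0; idx < N * C * H * W; ++idx) {
    NCHWIndex a = DecomposeNaive(idx, C, H, W);
    NCHWIndex b = DecomposeFast(idx, C, H, W);
    assert(a.batch == b.batch && a.c_out == b.c_out &&
           a.h_out == b.h_out && a.w_out == b.w_out);
  }
  printf("both decompositions agree\n");
  return 0;
}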
@@ -139,56 +139,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) {
 // A Cuda kernel to compute the depthwise convolution forward pass
 // in NCHW format.
-template <typename T, bool fuse_relu_before_conv>
+template <typename T, int c_filter, bool fuse_relu_before_conv>
 __device__ __inline__ void KernelDepthwiseConvNCHW(
     ARG_DEFINE_KernelDepthwiseConv) {
+  const int fw_size = c_filter != -1 ? c_filter : filter_width;
+  const int fh_size = c_filter != -1 ? c_filter : filter_height;
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= (output_channels * batch_size * output_height * output_width))
     return;
-  const int w_out = idx % output_width;
-  const int h_out = (idx / output_width) % output_height;
-  const int c_out = (idx / output_width / output_height) % output_channels;
-  const int batch = idx / output_width / output_height / output_channels;
+  int tmp_1 = idx / output_width;
+  const int w_out = idx - tmp_1 * output_width;
+  int tmp_2 = tmp_1 / output_height;
+  const int h_out = tmp_1 - tmp_2 * output_height;
+  tmp_1 = tmp_2;
+  tmp_2 = tmp_1 / output_channels;
+  const int c_out = tmp_1 - tmp_2 * output_channels;
+  const int batch = tmp_2;
   const int c_in = c_out / filter_multiplier;
-  const T* weight = filter_data + c_out * filter_height * filter_width;
   T value(0);
-  const int h_in_start = -padding_height + h_out * stride_height;
-  const int w_in_start = -padding_width + w_out * stride_width;
-  const int h_in_end = h_in_start + filter_height * dilate_height;
-  const int w_in_end = w_in_start + filter_width * dilate_width;
   int in_offset =
       ((batch * input_channels + c_in) * input_height) * input_width;
-  const int h_end = h_in_end < input_height ? h_in_end : input_height;
-  const int w_end = w_in_end < input_width ? w_in_end : input_width;
-  const int h_start = h_in_start > 0 ? h_in_start : 0;
-  const int w_start = w_in_start > 0 ? w_in_start : 0;
-  int weight_offset = 0;
+  int weight_offset = c_out * filter_height * filter_width;
+  int h_in_start = -padding_height + h_out * stride_height;
+  int w_in_start = -padding_width + w_out * stride_width;
 #pragma unroll
-  for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
+  for (int fh = 0, h_in = h_in_start; fh < fh_size;
+       fh++, h_in += dilate_height) {
 #pragma unroll
-    for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
-      if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
+    for (int fw = 0, w_in = w_in_start; fw < fw_size;
+         fw++, w_in += dilate_width) {
+      if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
         int offset = in_offset + h_in * input_width + w_in;
         T in_data = input_data[offset];
         if (fuse_relu_before_conv) {
-          value += weight[weight_offset] *
-                   T(max(0.0f, static_cast<double>(in_data)));
+          value += filter_data[weight_offset] *
+                   static_cast<T>(max(0.0f, static_cast<double>(in_data)));
         } else {
-          value += weight[weight_offset] * in_data;
+          value += filter_data[weight_offset] * in_data;
        }
      }
      weight_offset++;
    }
  }
-  int index = batch * output_channels * output_height * output_width +
-              c_out * output_height * output_width + h_out * output_width +
-              w_out;
-  output_data[index] = value;
+  output_data[idx] = value;
 }

 // A Cuda kernel to compute the depthwise convolution forward pass
@@ -229,7 +226,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
         T in_data = input_data[offset];
         const T* weight = filter_data + weight_offset * output_channels + c_out;
         if (fuse_relu_before_conv) {
-          value += weight[0] * T(max(0.0f, static_cast<double>(in_data)));
+          value += weight[0] *
+                   static_cast<T>(max(0.0f, static_cast<double>(in_data)));
         } else {
           value += weight[0] * in_data;
         }
@@ -282,7 +280,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
       int offset = in_offset + h_in * input_width + w_in;
       if (fuse_relu_before_conv) {
         value += r_weight[h_f * c_filter + w_f] *
-                 T(max(0.0f, static_cast<double>(input_data[offset])));
+                 static_cast<T>(
+                     max(0.0f, static_cast<double>(input_data[offset])));
       } else {
         value += r_weight[h_f * c_filter + w_f] * input_data[offset];
       }
@@ -338,7 +337,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
           in_offset + (h_in * input_width + w_in) * input_channels + c_in;
       if (fuse_relu_before_conv) {
         value += r_weight[h_f * c_filter + w_f] *
-                 T(max(0.0, static_cast<double>(input_data[offset])));
+                 static_cast<T>(
+                     max(0.0, static_cast<double>(input_data[offset])));
       } else {
         value += r_weight[h_f * c_filter + w_f] * input_data[offset];
       }
@@ -368,7 +368,8 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
   }
   if (c_filter == -1) {
     if (data_layout != DataLayout::kNHWC) {
-      KernelDepthwiseConvNCHW<T, fuse_relu_before_conv>(input_data,
+      KernelDepthwiseConvNCHW<T, c_filter, fuse_relu_before_conv>(
+          input_data,
           filter_data,
           batch_size,
           output_channels,
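With `c_filter` forwarded as a template argument here, `fh_size` and `fw_size` in KernelDepthwiseConvNCHW become compile-time constants whenever the launch path selects a specialized filter size, so the `#pragma unroll` loops can unroll fully. A minimal sketch of that dispatch pattern under assumed names (FilterSumKernel and LaunchFilterSum are hypothetical, not Paddle's actual launch code):

// unroll_dispatch_sketch.cu -- hypothetical demo of mapping a runtime filter
// size to a template parameter so the kernel loop gets a constant trip count.
#include <cstdio>
#include <cuda_runtime.h>

template <int c_filter>
__global__ void FilterSumKernel(const float* w, float* out, int filter_size) {
  // With c_filter != -1 the bound is a compile-time constant and the loop
  // fully unrolls; c_filter == -1 falls back to the runtime bound.
  const int fsize = c_filter != -1 ? c_filter : filter_size;
  float acc = 0.f;
#pragma unroll
  for (int i = 0; i < fsize; ++i) acc += w[i];
  out[0] = acc;
}

void LaunchFilterSum(const float* w, float* out, int filter_size) {
  switch (filter_size) {
    case 3: FilterSumKernel<3><<<1, 1>>>(w, out, filter_size); break;
    case 5: FilterSumKernel<5><<<1, 1>>>(w, out, filter_size); break;
    default: FilterSumKernel<-1><<<1, 1>>>(w, out, filter_size); break;
  }
}

int main() {
  float *w, *out;
  cudaMallocManaged(&w, 5 * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < 5; ++i) w[i] = 1.f;
  LaunchFilterSum(w, out, 5);
  cudaDeviceSynchronize();
  printf("sum = %f\n", *out);  // expect 5.0
  return 0;
}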
@@ -881,7 +882,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
                     image_wk;
         if (fuse_relu_before_conv) {
           s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-               T(max(0.0f, static_cast<double>(input_data[input_id])));
+               static_cast<T>(
+                   max(0.0f, static_cast<double>(input_data[input_id])));
         } else {
           s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                input_data[input_id];
@@ -942,7 +944,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
               kernel_id / filter_multiplier;
           if (fuse_relu_before_conv) {
             s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
-                 T(max(0.0f, static_cast<double>(input_data[input_id])));
+                 static_cast<T>(
+                     max(0.0f, static_cast<double>(input_data[input_id])));
           } else {
             s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
                  input_data[input_id];
@@ -1014,7 +1017,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
         T s(0);
         if (fuse_relu_before_conv) {
           s = output_grad_data[output_id] *
-              T(max(0.0f, static_cast<double>(input_data[input_id])));
+              static_cast<T>(
+                  max(0.0f, static_cast<double>(input_data[input_id])));
         } else {
           s = output_grad_data[output_id] * input_data[input_id];
         }
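Every fused-ReLU branch in the patch also swaps the functional-style cast `T(max(...))` for an explicit `static_cast<T>(max(...))`: the input is widened to double, clamped at zero, then narrowed back to T. A standalone sketch of that read pattern (the helper ReluLoad is hypothetical, not Paddle code):

// relu_load_sketch.cu -- isolates the fused-ReLU read used above.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__device__ __forceinline__ T ReluLoad(const T* data, int offset) {
  // Promote to double, clamp at zero, then narrow back with an explicit
  // static_cast instead of a functional-style T(...) cast.
  return static_cast<T>(max(0.0f, static_cast<double>(data[offset])));
}

__global__ void ReluLoadDemo(const float* in, float* out, int n) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i < n) out[i] = ReluLoad(in, i);
}

int main() {
  float *in, *out;
  cudaMallocManaged(&in, 4 * sizeof(float));
  cudaMallocManaged(&out, 4 * sizeof(float));
  float vals[4] = {-2.f, -0.5f, 0.f, 3.f};
  for (int i = 0; i < 4; ++i) in[i] = vals[i];
  ReluLoadDemo<<<1, 4>>>(in, out, 4);
  cudaDeviceSynchronize();
  for (int i = 0; i < 4; ++i) printf("%g ", out[i]);  // 0 0 0 3
  printf("\n");
  return 0;
}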