提交 5f2e8378 编写于 作者: D Dun 提交者: qingqing01

optimize depthwise conv by register memory (#13778)

* optimize depthwise conv by register memory
* test=develop
上级 5428cb99
......@@ -46,17 +46,20 @@ __forceinline__ __device__ unsigned warp_id() {
return ret;
}
#define ARG_DEFINE_KernelDepthwiseConv \
const T *const input_data, const T *const filter_data, const int batch_size, \
const int output_channels, const int output_height, \
const int output_width, const int input_channels, \
const int input_height, const int input_width, \
const int filter_multiplier, const int filter_height, \
const int filter_width, const int stride_height, const int stride_width, \
const int padding_height, const int padding_width, \
const int dilate_height, const int dilate_width, T *const output_data
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template <typename T>
__device__ __inline__ void KernelDepthwiseConv(
const T* const input_data, const T* const filter_data, const int batch_size,
const int output_channels, const int output_height, const int output_width,
const int input_channels, const int input_height, const int input_width,
const int filter_multiplier, const int filter_height,
const int filter_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* const output_data) {
__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
const int batch = blockIdx.y;
......@@ -97,42 +100,105 @@ __device__ __inline__ void KernelDepthwiseConv(
}
}
template <typename T, int c_filter_multiplier, int c_stride>
__global__ void KernelDepthwiseConvSp(
const T* const input_data, const T* const filter_data, const int batch_size,
const int output_channels, const int output_height, const int output_width,
const int input_channels, const int input_height, const int input_width,
const int filter_multiplier, const int filter_height,
const int filter_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* const output_data) {
if (c_filter_multiplier == 0)
KernelDepthwiseConv<T>(input_data, filter_data, batch_size, output_channels,
output_height, output_width, input_channels,
input_height, input_width, filter_multiplier,
filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width,
dilate_height, dilate_width, output_data);
template <typename T, int c_filter>
__device__ __inline__ void KernelDepthwiseConvCFilter(
ARG_DEFINE_KernelDepthwiseConv) {
const int kWeghtSize = c_filter * c_filter;
T r_weight[kWeghtSize];
const int batch = blockIdx.y;
const int c_out = blockIdx.x;
const T* weight = filter_data + c_out * c_filter * c_filter;
for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];
else
KernelDepthwiseConv<T>(input_data, filter_data, batch_size, output_channels,
output_height, output_width, input_channels,
input_height, input_width, c_filter_multiplier,
filter_height, filter_height, c_stride, c_stride,
padding_height, padding_width, dilate_height,
dilate_width, output_data);
for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
const int batch = blockIdx.y;
const int c_out = blockIdx.x;
const int c_in = c_out / filter_multiplier;
T value = 0;
const int h_in_start = -padding_height + h_out * stride_height;
const int w_in_start = -padding_width + w_out * stride_width;
const int h_in_end = h_in_start + c_filter * dilate_height;
const int w_in_end = w_in_start + c_filter * dilate_width;
const int in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int w_start = w_in_start > 0 ? w_in_start : 0;
for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
h_in += dilate_height, h_f++) {
for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
w_in += dilate_width, w_f++) {
if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
w_in < input_width) {
const int offset = in_offset + h_in * input_width + w_in;
value += r_weight[h_f * c_filter + w_f] * input_data[offset];
}
}
}
int index =
((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
w_out;
output_data[index] = value;
}
}
}
template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
if (c_filter_multiplier == 0) {
if (c_filter == -1)
KernelDepthwiseConv<T>(
input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, output_data);
else
KernelDepthwiseConvCFilter<T, c_filter>(
input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, output_data);
} else {
if (c_filter == -1)
KernelDepthwiseConv<T>(input_data, filter_data, batch_size,
output_channels, output_height, output_width,
input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_height,
c_stride, c_stride, padding_height, padding_width,
dilate_height, dilate_width, output_data);
else
KernelDepthwiseConvCFilter<T, c_filter>(
input_data, filter_data, batch_size, output_channels, output_height,
output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_height, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
output_data);
}
}
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad \
const T *const output_grad_data, const T *const filter_data, \
const int batch_size, const int output_channels, \
const int output_height, const int output_width, \
const int input_channels, const int input_height, const int input_width, \
const int filter_multiplier, const int filter_height, \
const int filter_width, const int stride_height, const int stride_width, \
const int padding_height, const int padding_width, \
const int dilate_height, const int dilate_width, \
T *const input_grad_data
template <typename T>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
const T* const output_grad_data, const T* const filter_data,
const int batch_size, const int output_channels, const int output_height,
const int output_width, const int input_channels, const int input_height,
const int input_width, const int filter_multiplier, const int filter_height,
const int filter_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* const input_grad_data) {
ARG_DEFINE_KernelDepthwiseConvInputGrad) {
for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
const int batch = blockIdx.y;
......@@ -184,15 +250,67 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
}
}
template <typename T, int c_filter_multiplier, int c_stride>
template <typename T, int c_filter, int c_filter_multiplier>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
ARG_DEFINE_KernelDepthwiseConvInputGrad) {
const int kWeghtSize = c_filter * c_filter * c_filter_multiplier + 1;
T r_weight[kWeghtSize];
const int batch = blockIdx.y;
const int c_in = blockIdx.x;
for (int c_i = 0; c_i < filter_multiplier; c_i++) {
int c_out = c_in * filter_multiplier + c_i;
const T* weight = filter_data + c_out * c_filter * c_filter;
for (int i = 0; i < c_filter * c_filter; i++)
r_weight[i + c_i * c_filter * c_filter] =
weight[c_filter * c_filter - i - 1];
}
for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
const int batch = blockIdx.y;
const int c_in = blockIdx.x;
int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;
int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
T value = 0;
for (int c_i = 0; c_i < filter_multiplier; c_i++) {
int c_out = c_in * filter_multiplier + c_i;
for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
h_out += dilate_height, h_f++) {
for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
w_out += dilate_width, w_f++) {
int s_h_out = h_out / stride_height;
int s_w_out = w_out / stride_width;
if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) {
const int output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
output_width +
s_w_out;
value +=
output_grad_data[output_grad_offset] *
r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
}
}
}
}
int index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
input_grad_data[index] = value;
}
}
}
template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
__global__ void KernelDepthwiseConvInputGradSp(
const T* const output_grad_data, const T* const filter_data,
const int batch_size, const int output_channels, const int output_height,
const int output_width, const int input_channels, const int input_height,
const int input_width, const int filter_multiplier, const int filter_height,
const int filter_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* const input_grad_data) {
ARG_DEFINE_KernelDepthwiseConvInputGrad) {
if (c_filter_multiplier == 0)
KernelDepthwiseConvInputGrad<T>(
output_grad_data, filter_data, batch_size, output_channels,
......@@ -200,13 +318,20 @@ __global__ void KernelDepthwiseConvInputGradSp(
filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height,
dilate_width, input_grad_data);
else
else if (c_filter == -1)
KernelDepthwiseConvInputGrad<T>(
output_grad_data, filter_data, batch_size, output_channels,
output_height, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
input_grad_data);
else
KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier>(
output_grad_data, filter_data, batch_size, output_channels,
output_height, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
padding_height, padding_width, dilate_height, dilate_width,
input_grad_data);
}
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
......@@ -325,12 +450,14 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
dim3 threads(std::min(output_width, thread), blocks, 1);
dim3 grid(output_channels, batch_size, 1);
int filter_multiplier = output_channels / input_channels;
#define check_case(c_filter_multiplier, c_stride) \
#define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride) { \
KernelDepthwiseConvSp<T, c_filter_multiplier, \
c_stride><<<grid, threads, 0, context.stream()>>>( \
stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \
KernelDepthwiseConvSp<T, c_filter_multiplier, c_stride, \
c_filter><<<grid, threads, 0, context.stream()>>>( \
input_data, filter_data, batch_size, output_channels, output_height, \
output_width, input_channels, input_height, input_width, \
filter_multiplier, ksize_height, ksize_width, stride_height, \
......@@ -338,11 +465,17 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
dilate_width, output_data); \
return; \
}
check_case(1, 1);
check_case(1, 2);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
check_case(0, 0);
check_case(1, 1, 3);
check_case(1, 1, 5);
check_case(1, 1, -1);
check_case(1, 2, 3);
check_case(1, 2, 5);
check_case(1, 2, -1);
check_case(0, 0, 3);
check_case(0, 0, 5);
check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
}
};
......@@ -384,13 +517,15 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
dim3 grid(input_channels, batch_size, 1);
int filter_multiplier = output_channels / input_channels;
#define check_case(c_filter_multiplier, c_stride) \
#define check_case(c_filter_multiplier, c_stride, c_filter) \
if (c_filter_multiplier == 0 || \
filter_multiplier == c_filter_multiplier && \
stride_height == stride_width && stride_height == c_stride) { \
stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \
KernelDepthwiseConvInputGradSp< \
T, c_filter_multiplier, \
c_stride><<<grid, threads, 0, context.stream()>>>( \
T, c_filter_multiplier, c_stride, \
c_filter><<<grid, threads, 0, context.stream()>>>( \
output_grad_data, filter_data, batch_size, output_channels, \
output_height, output_width, input_channels, input_height, \
input_width, filter_multiplier, ksize_height, ksize_width, \
......@@ -398,11 +533,21 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
dilate_height, dilate_width, input_grad_data); \
return; \
}
check_case(1, 1);
check_case(1, 2);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
check_case(0, 0);
check_case(1, 1, 3);
check_case(1, 1, 5);
check_case(1, 1, -1);
check_case(1, 2, 3);
check_case(1, 2, 5);
check_case(1, 2, -1);
check_case(2, 1, 3);
check_case(2, 1, 5);
check_case(2, 1, -1);
check_case(2, 2, 3);
check_case(2, 2, 5);
check_case(2, 2, -1);
check_case(0, 0, -1);
// NOTE(liangdun): 0,0 for other case
// add other case if needed, e.g. check_case(2^n,1)
#undef check_case
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册