提交 c8fe6c28 编写于 作者: 刘托

Merge branch 'gpu-conv-opt' into 'master'

Optimize convolution opencl kernel remove unused select.

See merge request !551
...@@ -53,22 +53,28 @@ __kernel void conv_2d(KERNEL_ERROR_PARAMS ...@@ -53,22 +53,28 @@ __kernel void conv_2d(KERNEL_ERROR_PARAMS
int in_width1 = in_width0 + in_width_stride; int in_width1 = in_width0 + in_width_stride;
int in_width2 = in_width1 + in_width_stride; int in_width2 = in_width1 + in_width_stride;
int in_width3 = in_width2 + in_width_stride; int in_width3 = in_width2 + in_width_stride;
const int height_idx = mad24((out_hb % out_height), stride, -padding_top); const int height_start = mad24((out_hb % out_height), stride, -padding_top);
int in_height_gap = select(
0,
(-height_start + dilation_h - 1) / dilation_h,
height_start < 0);
int in_height_start = mad24(in_height_gap, dilation_h, height_start);
int in_height_end = min(mad24(filter_height, dilation_h, height_start),
in_height);
const int batch_idx = mul24((out_hb / out_height), in_height); const int batch_idx = mul24((out_hb / out_height), in_height);
const int filter_hw = mul24(filter_width, filter_height); const int filter_hw = mul24(filter_width, filter_height);
const int filter_y_idx_start = mul24(out_ch_blk, filter_hw)
+ mul24(in_height_gap, filter_width);
DATA_TYPE4 in0, in1, in2, in3; DATA_TYPE4 in0, in1, in2, in3;
DATA_TYPE4 weights0, weights1, weights2, weights3; DATA_TYPE4 weights0, weights1, weights2, weights3;
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
const int in_idx = mul24(in_ch_blk, in_width); const int in_idx = mul24(in_ch_blk, in_width);
int filter_x_idx = in_ch_blk << 2; int filter_x_idx = in_ch_blk << 2;
int filter_y_idx = mul24(out_ch_blk, filter_hw); int filter_y_idx = filter_y_idx_start;
for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) { for (int hb_idx = in_height_start; hb_idx < in_height_end; hb_idx += dilation_h) {
int in_hb_value = height_idx + mul24(hb_idx, dilation_h); int in_hb_value = hb_idx + batch_idx;
in_hb_value = select(in_hb_value + batch_idx,
-1,
(in_hb_value < 0 || in_hb_value >= in_height));
#pragma unroll #pragma unroll
for (short width_idx = 0; width_idx < filter_width; ++width_idx) { for (short width_idx = 0; width_idx < filter_width; ++width_idx) {
......
...@@ -54,21 +54,26 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS ...@@ -54,21 +54,26 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS
int in_width2 = in_width1 + in_width_stride; int in_width2 = in_width1 + in_width_stride;
int in_width3 = in_width2 + in_width_stride; int in_width3 = in_width2 + in_width_stride;
int in_width4 = in_width3 + in_width_stride; int in_width4 = in_width3 + in_width_stride;
const int height_idx = mad24((out_hb % out_height), stride, -padding_top); const int height_start = mad24((out_hb % out_height), stride, -padding_top);
int in_height_gap = select(
0,
(-height_start + dilation_h - 1) / dilation_h,
height_start < 0);
int in_height_start = mad24(in_height_gap, dilation_h, height_start);
int in_height_end = min(mad24(3, dilation_h, height_start),
in_height);
const int batch_idx = mul24((out_hb / out_height), in_height); const int batch_idx = mul24((out_hb / out_height), in_height);
const int filter_y_idx_start = mul24(out_ch_blk, 9) + mul24(in_height_gap, 3);
DATA_TYPE4 in0, in1, in2, in3, in4; DATA_TYPE4 in0, in1, in2, in3, in4;
DATA_TYPE4 weights0, weights1, weights2, weights3; DATA_TYPE4 weights0, weights1, weights2, weights3;
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
const int in_idx = mul24(in_ch_blk, in_width); const int in_idx = mul24(in_ch_blk, in_width);
int filter_x_idx = in_ch_blk << 2; int filter_x_idx = in_ch_blk << 2;
int filter_y_idx = mul24(out_ch_blk, 9); int filter_y_idx = filter_y_idx_start;
int in_hb_idx = height_idx; for (int hb_idx = in_height_start; hb_idx < in_height_end; hb_idx += dilation_h) {
for (short hb_idx = 0; hb_idx < 3; ++hb_idx) { int in_hb_value = hb_idx + batch_idx;
int in_hb_value = select(in_hb_idx + batch_idx,
-1,
(in_hb_idx < 0 || in_hb_idx >= in_height));
int in_width_idx = 0; int in_width_idx = 0;
for (short width_idx = 0; width_idx < 3; ++width_idx) { for (short width_idx = 0; width_idx < 3; ++width_idx) {
int in_width_value; int in_width_value;
...@@ -122,7 +127,6 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS ...@@ -122,7 +127,6 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS
in_width_idx += dilation_w; in_width_idx += dilation_w;
filter_y_idx += 1; filter_y_idx += 1;
} }
in_hb_idx += dilation_h;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册