diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index fae002e4340edee0a4f14932db5e35a0beec840e..d23f6e435c33325f2c019f8959c4adff74e7491c 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -53,22 +53,28 @@ __kernel void conv_2d(KERNEL_ERROR_PARAMS int in_width1 = in_width0 + in_width_stride; int in_width2 = in_width1 + in_width_stride; int in_width3 = in_width2 + in_width_stride; - const int height_idx = mad24((out_hb % out_height), stride, -padding_top); + const int height_start = mad24((out_hb % out_height), stride, -padding_top); + int in_height_gap = select( + 0, + (-height_start + dilation_h - 1) / dilation_h, + height_start < 0); + int in_height_start = mad24(in_height_gap, dilation_h, height_start); + int in_height_end = min(mad24(filter_height, dilation_h, height_start), + in_height); const int batch_idx = mul24((out_hb / out_height), in_height); const int filter_hw = mul24(filter_width, filter_height); + const int filter_y_idx_start = mul24(out_ch_blk, filter_hw) + + mul24(in_height_gap, filter_width); DATA_TYPE4 in0, in1, in2, in3; DATA_TYPE4 weights0, weights1, weights2, weights3; for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { const int in_idx = mul24(in_ch_blk, in_width); int filter_x_idx = in_ch_blk << 2; - int filter_y_idx = mul24(out_ch_blk, filter_hw); - for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) { - int in_hb_value = height_idx + mul24(hb_idx, dilation_h); - in_hb_value = select(in_hb_value + batch_idx, - -1, - (in_hb_value < 0 || in_hb_value >= in_height)); + int filter_y_idx = filter_y_idx_start; + for (int hb_idx = in_height_start; hb_idx < in_height_end; hb_idx += dilation_h) { + int in_hb_value = hb_idx + batch_idx; #pragma unroll for (short width_idx = 0; width_idx < filter_width; ++width_idx) { diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index 1389af6654172009bccea78ac60c58b8b50a771c..13e4ccb347cc4a42c67e7f8dc7e740cf336af0d9 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -54,21 +54,26 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS int in_width2 = in_width1 + in_width_stride; int in_width3 = in_width2 + in_width_stride; int in_width4 = in_width3 + in_width_stride; - const int height_idx = mad24((out_hb % out_height), stride, -padding_top); + const int height_start = mad24((out_hb % out_height), stride, -padding_top); + int in_height_gap = select( + 0, + (-height_start + dilation_h - 1) / dilation_h, + height_start < 0); + int in_height_start = mad24(in_height_gap, dilation_h, height_start); + int in_height_end = min(mad24(3, dilation_h, height_start), + in_height); const int batch_idx = mul24((out_hb / out_height), in_height); + const int filter_y_idx_start = mul24(out_ch_blk, 9) + mul24(in_height_gap, 3); DATA_TYPE4 in0, in1, in2, in3, in4; DATA_TYPE4 weights0, weights1, weights2, weights3; for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { const int in_idx = mul24(in_ch_blk, in_width); int filter_x_idx = in_ch_blk << 2; - int filter_y_idx = mul24(out_ch_blk, 9); - int in_hb_idx = height_idx; - for (short hb_idx = 0; hb_idx < 3; ++hb_idx) { - int in_hb_value = select(in_hb_idx + batch_idx, - -1, - (in_hb_idx < 0 || in_hb_idx >= in_height)); + int filter_y_idx = filter_y_idx_start; + for (int hb_idx = in_height_start; hb_idx < in_height_end; hb_idx += dilation_h) { + int in_hb_value = hb_idx + batch_idx; int in_width_idx = 0; for (short width_idx = 0; width_idx < 3; ++width_idx) { int in_width_value; @@ -122,7 +127,6 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS in_width_idx += dilation_w; filter_y_idx += 1; } - in_hb_idx += dilation_h; } }