From ba4ca883184034ce6b3e243c8b9dc54e5d8ee498 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 15 Nov 2017 16:07:27 +0800 Subject: [PATCH] Adjust the position of judge clauses of conv opencl kernel. --- mace/kernels/opencl/cl/conv_2d_1x1.cl | 85 +++++++++++++++------------ mace/kernels/opencl/cl/conv_2d_3x3.cl | 15 +++-- 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index 020e6bdc..8025074f 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -24,25 +24,14 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */ } } -#define vec_conv_2d_1x1_s1(out_size) \ -do { \ +#define vec_conv_2d_1x1_s1 \ float4 in0 = vload4(0, input_ptr); \ float4 in1 = vload4(0, input_ptr + in_pixel); \ float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \ - float4 in3 = vload4(0, input_ptr + 3 * in_pixel); \ - for (int oc = 0; oc < out_size; ++oc) { \ - float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \ - float4 out = vload4(0, output_ptr + oc * out_pixel); \ - out += in0 * weights.x; \ - out += in1 * weights.y; \ - out += in2 * weights.z; \ - out += in3 * weights.w; \ - vstore4(out, 0, output_ptr + oc * out_pixel); \ - } \ -} while(0) + float4 in3 = vload4(0, input_ptr + 3 * in_pixel); -#define vec_conv_2d_1x1_s2(out_size) \ -do { \ +#define vec_conv_2d_1x1_s2 \ float4 in00 = vload4(0, input_ptr); \ float3 in01 = vload3(0, input_ptr + 4); \ float4 in10 = vload4(0, input_ptr + in_pixel); \ @@ -54,8 +43,11 @@ do { \ float4 in0 = (float4)(in00.s02, in01.s02); \ float4 in1 = (float4)(in10.s02, in11.s02); \ float4 in2 = (float4)(in20.s02, in21.s02); \ - float4 in3 = (float4)(in30.s02, in31.s02); \ - for (int oc = 0; oc < out_size; ++oc) { \ + float4 in3 = (float4)(in30.s02, in31.s02); + + +#define vec_conv_2d_1x1_compute_loop \ + for (int oc = 0; oc < 4; ++oc) { \ float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \ 
float4 out = vload4(0, output_ptr + oc * out_pixel); \ out += in0 * weights.x; \ @@ -63,10 +55,16 @@ do { \ out += in2 * weights.z; \ out += in3 * weights.w; \ vstore4(out, 0, output_ptr + oc * out_pixel); \ - } \ -} while(0) - + } +#define vec_conv_2d_1x1_compute \ + float4 weights = vload4(0, filter_ptr); \ + float4 out = vload4(0, output_ptr); \ + out += in0 * weights.x; \ + out += in1 * weights.y; \ + out += in2 * weights.z; \ + out += in3 * weights.w; \ + vstore4(out, 0, output_ptr); __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ __global const float *filter, /* o, i, kh, kw */ @@ -115,25 +113,38 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ int in_chan = 0; if (pixel_len == 4) { - for (; in_chan + 3 < in_chan_num; in_chan += 4) { - const float *input_ptr = input_base + in_chan * in_pixel; - int out_chan = out_chan_begin; - for (; out_chan + 3 < out_chan_end; out_chan += 4) { - const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; - float *output_ptr = output_base + out_chan * out_pixel; - if (stride == 1) { - vec_conv_2d_1x1_s1(4); - } else if (stride == 2) { - vec_conv_2d_1x1_s2(4); + if (stride == 1) { + for (; in_chan + 3 < in_chan_num; in_chan += 4) { + const float *input_ptr = input_base + in_chan * in_pixel; + int out_chan = out_chan_begin; + for (; out_chan + 3 < out_chan_end; out_chan += 4) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s1; + vec_conv_2d_1x1_compute_loop; + } + for (; out_chan < out_chan_end; ++out_chan) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s1; + vec_conv_2d_1x1_compute; } } - for (; out_chan < out_chan_end; ++out_chan) { - const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; - float *output_ptr = output_base + out_chan * out_pixel; - 
if (stride == 1) { - vec_conv_2d_1x1_s1(1); - } else if (stride == 2) { - vec_conv_2d_1x1_s2(1); + } else if (stride == 2) { + for (; in_chan + 3 < in_chan_num; in_chan += 4) { + const float *input_ptr = input_base + in_chan * in_pixel; + int out_chan = out_chan_begin; + for (; out_chan + 3 < out_chan_end; out_chan += 4) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s2; + vec_conv_2d_1x1_compute_loop; + } + for (; out_chan < out_chan_end; ++out_chan) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s2; + vec_conv_2d_1x1_compute; } } } diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index b3f7735d..317daaaf 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input, if (pixels == 4) { float4 res = bias == NULL ? 
0 : (float4)bias[i]; - for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) { - const float* input_ptr = input_base + in_chan_idx * in_pixel; - const float* filter_ptr = filter_base + in_chan_idx * 9; - if (stride_w == 1) { + + if (stride_w == 1) { + for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) { + const float* input_ptr = input_base + in_chan_idx * in_pixel; + const float* filter_ptr = filter_base + in_chan_idx * 9; res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3); res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3); res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3); - } else { + } + } else { + for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) { + const float* input_ptr = input_base + in_chan_idx * in_pixel; + const float* filter_ptr = filter_base + in_chan_idx * 9; res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3); res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3); res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3); -- GitLab