提交 ba4ca883 编写于 作者: L liuqi

Adjust the postion of judge clauses of conv opencl kernel.

上级 5bbd271e
...@@ -24,25 +24,14 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */ ...@@ -24,25 +24,14 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */
} }
} }
#define vec_conv_2d_1x1_s1(out_size) \ #define vec_conv_2d_1x1_s1 \
do { \
float4 in0 = vload4(0, input_ptr); \ float4 in0 = vload4(0, input_ptr); \
float4 in1 = vload4(0, input_ptr + in_pixel); \ float4 in1 = vload4(0, input_ptr + in_pixel); \
float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \ float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \
float4 in3 = vload4(0, input_ptr + 3 * in_pixel); \ float4 in3 = vload4(0, input_ptr + 3 * in_pixel);
for (int oc = 0; oc < out_size; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \
} \
} while(0)
#define vec_conv_2d_1x1_s2(out_size) \
do { \ #define vec_conv_2d_1x1_s2 \
float4 in00 = vload4(0, input_ptr); \ float4 in00 = vload4(0, input_ptr); \
float3 in01 = vload3(0, input_ptr + 4); \ float3 in01 = vload3(0, input_ptr + 4); \
float4 in10 = vload4(0, input_ptr + in_pixel); \ float4 in10 = vload4(0, input_ptr + in_pixel); \
...@@ -54,8 +43,11 @@ do { \ ...@@ -54,8 +43,11 @@ do { \
float4 in0 = (float4)(in00.s02, in01.s02); \ float4 in0 = (float4)(in00.s02, in01.s02); \
float4 in1 = (float4)(in10.s02, in11.s02); \ float4 in1 = (float4)(in10.s02, in11.s02); \
float4 in2 = (float4)(in20.s02, in21.s02); \ float4 in2 = (float4)(in20.s02, in21.s02); \
float4 in3 = (float4)(in30.s02, in31.s02); \ float4 in3 = (float4)(in30.s02, in31.s02);
for (int oc = 0; oc < out_size; ++oc) { \
#define vec_conv_2d_1x1_compute_loop \
for (int oc = 0; oc < 4; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \ float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \ float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \ out += in0 * weights.x; \
...@@ -63,10 +55,16 @@ do { \ ...@@ -63,10 +55,16 @@ do { \
out += in2 * weights.z; \ out += in2 * weights.z; \
out += in3 * weights.w; \ out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \ vstore4(out, 0, output_ptr + oc * out_pixel); \
} \ }
} while(0)
#define vec_conv_2d_1x1_compute \
float4 weights = vload4(0, filter_ptr); \
float4 out = vload4(0, output_ptr); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr);
__kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
__global const float *filter, /* o, i, kh, kw */ __global const float *filter, /* o, i, kh, kw */
...@@ -115,25 +113,38 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ ...@@ -115,25 +113,38 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
int in_chan = 0; int in_chan = 0;
if (pixel_len == 4) { if (pixel_len == 4) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) { if (stride == 1) {
const float *input_ptr = input_base + in_chan * in_pixel; for (; in_chan + 3 < in_chan_num; in_chan += 4) {
int out_chan = out_chan_begin; const float *input_ptr = input_base + in_chan * in_pixel;
for (; out_chan + 3 < out_chan_end; out_chan += 4) { int out_chan = out_chan_begin;
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; for (; out_chan + 3 < out_chan_end; out_chan += 4) {
float *output_ptr = output_base + out_chan * out_pixel; const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
if (stride == 1) { float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s1(4); vec_conv_2d_1x1_s1;
} else if (stride == 2) { vec_conv_2d_1x1_compute_loop;
vec_conv_2d_1x1_s2(4); }
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s1;
vec_conv_2d_1x1_compute;
} }
} }
for (; out_chan < out_chan_end; ++out_chan) { } else if (stride == 2) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; for (; in_chan + 3 < in_chan_num; in_chan += 4) {
float *output_ptr = output_base + out_chan * out_pixel; const float *input_ptr = input_base + in_chan * in_pixel;
if (stride == 1) { int out_chan = out_chan_begin;
vec_conv_2d_1x1_s1(1); for (; out_chan + 3 < out_chan_end; out_chan += 4) {
} else if (stride == 2) { const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
vec_conv_2d_1x1_s2(1); float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s2;
vec_conv_2d_1x1_compute_loop;
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s2;
vec_conv_2d_1x1_compute;
} }
} }
} }
......
...@@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input, ...@@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input,
if (pixels == 4) { if (pixels == 4) {
float4 res = bias == NULL ? 0 : (float4)bias[i]; float4 res = bias == NULL ? 0 : (float4)bias[i];
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel; if (stride_w == 1) {
const float* filter_ptr = filter_base + in_chan_idx * 9; for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
if (stride_w == 1) { const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3); res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3); res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3); res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
} else { }
} else {
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3); res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3); res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3); res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册