提交 ba4ca883 编写于 作者: L liuqi

Adjust the position of the judge (conditional) clauses in the conv OpenCL kernels.

上级 5bbd271e
......@@ -24,25 +24,14 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */
}
}
#define vec_conv_2d_1x1_s1(out_size) \
do { \
#define vec_conv_2d_1x1_s1 \
float4 in0 = vload4(0, input_ptr); \
float4 in1 = vload4(0, input_ptr + in_pixel); \
float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \
float4 in3 = vload4(0, input_ptr + 3 * in_pixel); \
for (int oc = 0; oc < out_size; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \
} \
} while(0)
float4 in3 = vload4(0, input_ptr + 3 * in_pixel);
#define vec_conv_2d_1x1_s2(out_size) \
do { \
#define vec_conv_2d_1x1_s2 \
float4 in00 = vload4(0, input_ptr); \
float3 in01 = vload3(0, input_ptr + 4); \
float4 in10 = vload4(0, input_ptr + in_pixel); \
......@@ -54,8 +43,11 @@ do { \
float4 in0 = (float4)(in00.s02, in01.s02); \
float4 in1 = (float4)(in10.s02, in11.s02); \
float4 in2 = (float4)(in20.s02, in21.s02); \
float4 in3 = (float4)(in30.s02, in31.s02); \
for (int oc = 0; oc < out_size; ++oc) { \
float4 in3 = (float4)(in30.s02, in31.s02);
#define vec_conv_2d_1x1_compute_loop \
for (int oc = 0; oc < 4; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \
......@@ -63,10 +55,16 @@ do { \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \
} \
} while(0)
}
/*
 * vec_conv_2d_1x1_compute: accumulate four input vectors into one output
 * vector and write it back.
 *
 * Expands to plain statements (not wrapped in a block), so it declares
 * `weights` and `out` directly in the enclosing scope; expand it at most
 * once per scope.
 *
 * Names that must already be in scope at the expansion site:
 *   in0, in1, in2, in3 - float4 inputs, produced by vec_conv_2d_1x1_s1 or
 *                        vec_conv_2d_1x1_s2 in the visible call sites
 *   filter_ptr         - four floats are read from it; they are applied to
 *                        in0..in3 as weights.x, .y, .z, .w respectively
 *   output_ptr         - four floats are read, updated, and stored back
 *
 * In the visible callers this is the path for the leftover output channels
 * that are processed one at a time after the 4-wide compute loop.
 */
#define vec_conv_2d_1x1_compute \
float4 weights = vload4(0, filter_ptr); \
float4 out = vload4(0, output_ptr); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr);
__kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
__global const float *filter, /* o, i, kh, kw */
......@@ -115,25 +113,38 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
int in_chan = 0;
if (pixel_len == 4) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * in_pixel;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
if (stride == 1) {
vec_conv_2d_1x1_s1(4);
} else if (stride == 2) {
vec_conv_2d_1x1_s2(4);
if (stride == 1) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * in_pixel;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s1;
vec_conv_2d_1x1_compute_loop;
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s1;
vec_conv_2d_1x1_compute;
}
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
if (stride == 1) {
vec_conv_2d_1x1_s1(1);
} else if (stride == 2) {
vec_conv_2d_1x1_s2(1);
} else if (stride == 2) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * in_pixel;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s2;
vec_conv_2d_1x1_compute_loop;
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s2;
vec_conv_2d_1x1_compute;
}
}
}
......
......@@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input,
if (pixels == 4) {
float4 res = bias == NULL ? 0 : (float4)bias[i];
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
if (stride_w == 1) {
if (stride_w == 1) {
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
} else {
}
} else {
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册