提交 e82ebee8 编写于 作者: X xiebaiyuan 提交者: GitHub

optimise conv 1x1 ,test=develop (#2248)

上级 8006f5e6
...@@ -55,7 +55,7 @@ __kernel void conv_3x3(__private const int global_size_dim0, ...@@ -55,7 +55,7 @@ __kernel void conv_3x3(__private const int global_size_dim0,
const int out_c = get_global_id(0); const int out_c = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
const int out_nh = get_global_id(2); const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
if (out_c >= global_size_dim0 || if (out_c >= global_size_dim0 ||
...@@ -956,7 +956,7 @@ __kernel void conv_1x1(__private const int global_size_dim0, ...@@ -956,7 +956,7 @@ __kernel void conv_1x1(__private const int global_size_dim0,
const int out_nh = get_global_id(2); const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP | CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST; CLK_FILTER_NEAREST;
...@@ -1035,7 +1035,7 @@ __kernel void conv_1x1_spl( ...@@ -1035,7 +1035,7 @@ __kernel void conv_1x1_spl(
int out_w1 = out_w + global_size_dim1; int out_w1 = out_w + global_size_dim1;
int out_w2 = out_w + global_size_dim1 * 2; int out_w2 = out_w + global_size_dim1 * 2;
int out_w3 = out_w + global_size_dim1 * 3; int out_w3 = out_w + global_size_dim1 * 3;
int outpos_main = mul24(out_c , old_w); int outpos_main = mul24(out_c , old_w);
int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh);
int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh);
...@@ -1093,137 +1093,96 @@ __kernel void conv_1x1_spl( ...@@ -1093,137 +1093,96 @@ __kernel void conv_1x1_spl(
half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2)); half4 weight2 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 2));
half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3)); half4 weight3 = read_imageh(filter, sampler, (int2)(out_c, i * 4 + 3));
if ((max_w_bound - pos_in.x-1) < input_width && (max_w_bound - pos_in.x-1)>=0 ){ int bound_gap = max_w_bound - pos_in.x - 1;
if (bound_gap < input_width && bound_gap >= 0){
if (burndary_index==0){ if (burndary_index==0){
output0 = mad(input0.x, weight0, output0); // do nothing
output0 = mad(input0.y, weight1, output0);
output0 = mad(input0.z, weight2, output0);
output0 = mad(input0.w, weight3, output0);
} else if (burndary_index==1){ } else if (burndary_index==1){
output0 = mad(input0.x, weight0, output0); input0.w = 0.0f;
output0 = mad(input0.y, weight1, output0);
output0 = mad(input0.z, weight2, output0);
output0 = mad(0.0f, weight3, output0);
} else if (burndary_index==2){ } else if (burndary_index==2){
output0 = mad(input0.x, weight0, output0); input0.z = 0.0f;
output0 = mad(input0.y, weight1, output0); input0.w = 0.0f;
output0 = mad(0.0f, weight2, output0);
output0 = mad(0.0f, weight3, output0);
} else if (burndary_index==3){ } else if (burndary_index==3){
output0 = mad(input0.x, weight0, output0); input0.y = 0.0f;
output0 = mad(0.0f, weight1, output0); input0.z = 0.0f;
output0 = mad(0.0f, weight2, output0); input0.w = 0.0f;
output0 = mad(0.0f, weight3, output0);
} }
}else {
output0 = mad(input0.x, weight0, output0);
output0 = mad(input0.y, weight1, output0);
output0 = mad(input0.z, weight2, output0);
output0 = mad(input0.w, weight3, output0);
} }
output0 = mad(input0.x, weight0, output0);
output0 = mad(input0.y, weight1, output0);
output0 = mad(input0.z, weight2, output0);
output0 = mad(input0.w, weight3, output0);
// -------------1-------------- // -------------1--------------
pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, in_pos_in_one_block1.y); pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, in_pos_in_one_block1.y);
half4 input1 = read_imageh(input_image, sampler, pos_in); half4 input1 = read_imageh(input_image, sampler, pos_in);
if (abs(max_w_bound - pos_in.x) < input_width){ bound_gap = max_w_bound - pos_in.x - 1;
if (bound_gap < input_width && bound_gap >= 0){
if (burndary_index==0){ if (burndary_index==0){
output1 = mad(input1.x, weight0, output1); // do nothing
output1 = mad(input1.y, weight1, output1);
output1 = mad(input1.z, weight2, output1);
output1 = mad(input1.w, weight3, output1);
} else if (burndary_index==1){ } else if (burndary_index==1){
output1 = mad(input1.x, weight0, output1); input1.w = 0.0f;
output1 = mad(input1.y, weight1, output1);
output1 = mad(input1.z, weight2, output1);
output1 = mad(0.0f, weight3, output1);
} else if (burndary_index==2){ } else if (burndary_index==2){
output1 = mad(input1.x, weight0, output1); input1.z = 0.0f;
output1 = mad(input1.y, weight1, output1); input1.w = 0.0f;
output1 = mad(0.0f, weight2, output1);
output1 = mad(0.0f, weight3, output1);
} else if (burndary_index==3){ } else if (burndary_index==3){
output1 = mad(input1.x, weight0, output1); input1.y = 0.0f;
output1 = mad(0.0f, weight1, output1); input1.z = 0.0f;
output1 = mad(0.0f, weight2, output1); input1.w = 0.0f;
output1 = mad(0.0f, weight3, output1);
} }
}else {
output1 = mad(input1.x, weight0, output1);
output1 = mad(input1.y, weight1, output1);
output1 = mad(input1.z, weight2, output1);
output1 = mad(input1.w, weight3, output1);
} }
output1 = mad(input1.x, weight0, output1);
output1 = mad(input1.y, weight1, output1);
output1 = mad(input1.z, weight2, output1);
output1 = mad(input1.w, weight3, output1);
// -------------2-------------- // -------------2--------------
pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, in_pos_in_one_block2.y); pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, in_pos_in_one_block2.y);
half4 input2 = read_imageh(input_image, sampler, pos_in); half4 input2 = read_imageh(input_image, sampler, pos_in);
if (abs(max_w_bound - pos_in.x) < input_width){ bound_gap = max_w_bound - pos_in.x - 1;
if (bound_gap < input_width && bound_gap >= 0){
if (burndary_index==0){ if (burndary_index==0){
output2 = mad(input2.x, weight0, output2); // do nothing
output2 = mad(input2.y, weight1, output2);
output2 = mad(input2.z, weight2, output2);
output2 = mad(input2.w, weight3, output2);
} else if (burndary_index==1){ } else if (burndary_index==1){
output2 = mad(input2.x, weight0, output2); input2.w = 0.0f;
output2 = mad(input2.y, weight1, output2);
output2 = mad(input2.z, weight2, output2);
output2 = mad(0.0f, weight3, output2);
} else if (burndary_index==2){ } else if (burndary_index==2){
output2 = mad(input2.x, weight0, output2); input2.z = 0.0f;
output2 = mad(input2.y, weight1, output2); input2.w = 0.0f;
output2 = mad(0.0f, weight2, output2);
output2 = mad(0.0f, weight3, output2);
} else if (burndary_index==3){ } else if (burndary_index==3){
output2 = mad(input2.x, weight0, output2); input2.y = 0.0f;
output2 = mad(0.0f, weight1, output2); input2.z = 0.0f;
output2 = mad(0.0f, weight2, output2); input2.w = 0.0f;
output2 = mad(0.0f, weight3, output2);
} }
}else {
output2 = mad(input2.x, weight0, output2);
output2 = mad(input2.y, weight1, output2);
output2 = mad(input2.z, weight2, output2);
output2 = mad(input2.w, weight3, output2);
} }
output2 = mad(input2.x, weight0, output2);
output2 = mad(input2.y, weight1, output2);
output2 = mad(input2.z, weight2, output2);
output2 = mad(input2.w, weight3, output2);
// -------------3-------------- // -------------3--------------
pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, in_pos_in_one_block3.y); pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, in_pos_in_one_block3.y);
half4 input3 = read_imageh(input_image, sampler, pos_in); half4 input3 = read_imageh(input_image, sampler, pos_in);
bound_gap = max_w_bound - pos_in.x - 1;
if (abs(max_w_bound - pos_in.x) < input_width){ if (bound_gap < input_width && bound_gap >= 0){
if (burndary_index==0){ if (burndary_index==0){
output3 = mad(input3.x, weight0, output3); // do nothing
output3 = mad(input3.y, weight1, output3);
output3 = mad(input3.z, weight2, output3);
output3 = mad(input3.w, weight3, output3);
} else if (burndary_index==1){ } else if (burndary_index==1){
output3 = mad(input3.x, weight0, output3); input3.w = 0.0f;
output3 = mad(input3.y, weight1, output3);
output3 = mad(input3.z, weight2, output3);
output3 = mad(0.0f, weight3, output3);
} else if (burndary_index==2){ } else if (burndary_index==2){
output3 = mad(input3.x, weight0, output3); input3.z = 0.0f;
output3 = mad(input3.y, weight1, output3); input3.w = 0.0f;
output3 = mad(0.0f, weight2, output3);
output3 = mad(0.0f, weight3, output3);
} else if (burndary_index==3){ } else if (burndary_index==3){
output3 = mad(input3.x, weight0, output3); input3.y = 0.0f;
output3 = mad(0.0f, weight1, output3); input3.z = 0.0f;
output3 = mad(0.0f, weight2, output3); input3.w = 0.0f;
output3 = mad(0.0f, weight3, output3);
} }
}else {
output3 = mad(input3.x, weight0, output3);
output3 = mad(input3.y, weight1, output3);
output3 = mad(input3.z, weight2, output3);
output3 = mad(input3.w, weight3, output3);
} }
output3 = mad(input3.x, weight0, output3);
output3 = mad(input3.y, weight1, output3);
output3 = mad(input3.z, weight2, output3);
output3 = mad(input3.w, weight3, output3);
} }
#ifdef BATCH_NORM #ifdef BATCH_NORM
...@@ -1292,7 +1251,7 @@ __kernel void conv_7x7(__private const int global_size_dim0, ...@@ -1292,7 +1251,7 @@ __kernel void conv_7x7(__private const int global_size_dim0,
const int out_c = get_global_id(0); const int out_c = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
const int out_nh = get_global_id(2); const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
if (out_c >= global_size_dim0 || if (out_c >= global_size_dim0 ||
...@@ -1764,7 +1723,7 @@ __kernel void conv_5x5(__private const int global_size_dim0, ...@@ -1764,7 +1723,7 @@ __kernel void conv_5x5(__private const int global_size_dim0,
const int out_c = get_global_id(0); const int out_c = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
const int out_nh = get_global_id(2); const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
if (out_c >= global_size_dim0 || if (out_c >= global_size_dim0 ||
...@@ -1884,7 +1843,7 @@ __kernel void convBNAdd_3x3(__private const int global_size_dim0, ...@@ -1884,7 +1843,7 @@ __kernel void convBNAdd_3x3(__private const int global_size_dim0,
const int out_c = get_global_id(0); const int out_c = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
const int out_nh = get_global_id(2); const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
if (out_c >= global_size_dim0 || if (out_c >= global_size_dim0 ||
...@@ -2193,7 +2152,7 @@ __kernel void convBNAdd_1x1(__private const int global_size_dim0, ...@@ -2193,7 +2152,7 @@ __kernel void convBNAdd_1x1(__private const int global_size_dim0,
const int out_nh = get_global_id(2); const int out_nh = get_global_id(2);
int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP | CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST; CLK_FILTER_NEAREST;
...@@ -2301,7 +2260,7 @@ __kernel void convBNAdd_1x1_spl( ...@@ -2301,7 +2260,7 @@ __kernel void convBNAdd_1x1_spl(
int2 in_pos_in_one_block3 = int2 in_pos_in_one_block3 =
ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset);
half4 output0 = 0.0f; half4 output0 = 0.0f;
half4 output1 = 0.0f; half4 output1 = 0.0f;
half4 output2 = 0.0f; half4 output2 = 0.0f;
...@@ -2393,7 +2352,7 @@ __kernel void convBNAdd_1x1_spl( ...@@ -2393,7 +2352,7 @@ __kernel void convBNAdd_1x1_spl(
output2 += read_imageh(bias, sampler, output_pos2); output2 += read_imageh(bias, sampler, output_pos2);
output3 += read_imageh(bias, sampler, output_pos3); output3 += read_imageh(bias, sampler, output_pos3);
#endif #endif
#ifdef RELU #ifdef RELU
output0 = activation(output0); output0 = activation(output0);
output1 = activation(output1); output1 = activation(output1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册