未验证 提交 1c20510e 编写于 作者: Y ysh329 提交者: GitHub

[BugFix][KERNEL][OPENCL] Fix conv3x3 group. test=develop (#4236)

* fix conv3x3 group. test=develop
* remove useless. test=develop
* optimize 3x3 group
上级 47d21d28
...@@ -30,13 +30,11 @@ __kernel void conv2d_3x3(__private const int global_size_dim0, ...@@ -30,13 +30,11 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
__private const int output_width, __private const int output_width,
__private const int output_height, __private const int output_height,
__private const int output_c, __private const int output_c,
__private const int filter_channel, __private const int filter_tensor_c,
__private const int filter_width, __private const int filter_width,
__private const int filter_height, __private const int filter_height,
__private const int group, __private const int group,
__private const int input_tensor_c __private const int input_tensor_c) {
) {
const int out_c = get_global_id(0); const int out_c = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
...@@ -72,408 +70,112 @@ __kernel void conv2d_3x3(__private const int global_size_dim0, ...@@ -72,408 +70,112 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
#else #else
CL_DTYPE4 output = (CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f); CL_DTYPE4 output = (CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f);
#endif #endif
CL_DTYPE4 zero_dtype4 = (CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f);
CL_DTYPE4 input[9]; // 3x3 region of input CL_DTYPE4 input0, input1, input2, input3, input4, input5, input6, input7, input8;
if (group == 1) {
for (int i = 0; i < input_c; ++i) { // each run for 3x3
int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x,
in_pos_in_one_block.y);
input[0] = select(
READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x - dilation, pos_in.y - dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 ||
in_pos_in_one_block.y - dilation < 0 ||
in_pos_in_one_block.x - dilation >= input_width ||
in_pos_in_one_block.y - dilation >= input_height)
<< 15));
input[1] =
select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x, pos_in.y - dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x < 0 ||
in_pos_in_one_block.y - dilation < 0 ||
in_pos_in_one_block.x >= input_width ||
in_pos_in_one_block.y - dilation >= input_height)
<< 15));
input[2] = select(
READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x + dilation, pos_in.y - dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 ||
in_pos_in_one_block.y - dilation < 0 ||
in_pos_in_one_block.x + dilation >= input_width ||
in_pos_in_one_block.y - dilation >= input_height)
<< 15));
input[3] =
select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x - dilation, pos_in.y)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 ||
in_pos_in_one_block.y < 0 ||
in_pos_in_one_block.x - dilation >= input_width ||
in_pos_in_one_block.y >= input_height)
<< 15));
input[4] = select(
READ_IMG_TYPE(
CL_DTYPE_CHAR, input_image, sampler, (int2)(pos_in.x, pos_in.y)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 ||
in_pos_in_one_block.x >= input_width ||
in_pos_in_one_block.y >= input_height)
<< 15));
input[5] =
select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x + dilation, pos_in.y)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 ||
in_pos_in_one_block.y < 0 ||
in_pos_in_one_block.x + dilation >= input_width ||
in_pos_in_one_block.y >= input_height)
<< 15));
input[6] = select(
READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x - dilation, pos_in.y + dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x - dilation < 0 ||
in_pos_in_one_block.y + dilation < 0 ||
in_pos_in_one_block.x - dilation >= input_width ||
in_pos_in_one_block.y + dilation >= input_height)
<< 15));
input[7] =
select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x, pos_in.y + dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x < 0 ||
in_pos_in_one_block.y + dilation < 0 ||
in_pos_in_one_block.x >= input_width ||
in_pos_in_one_block.y + dilation >= input_height)
<< 15));
input[8] = select(
READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image,
sampler,
(int2)(pos_in.x + dilation, pos_in.y + dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f),
(ushort4)((in_pos_in_one_block.x + dilation < 0 ||
in_pos_in_one_block.y + dilation < 0 ||
in_pos_in_one_block.x + dilation >= input_width ||
in_pos_in_one_block.y + dilation >= input_height)
<< 15));
if (i == input_c - 1) {
int c_shr = input_tensor_c % 4;
if (c_shr == 1) {
for (int k = 0; k < 9; k++) {
input[k].y = (half)0.f;
input[k].z = (half)0.f;
input[k].w = (half)0.f;
}
} else if (c_shr == 2) {
for (int k = 0; k < 9; k++) {
input[k].z = (half)0.f;
input[k].w = (half)0.f;
}
} else if (c_shr == 3) {
for (int k = 0; k < 9; k++) {
input[k].w = (half)0.f;
}
} else if (c_shr == 0) {
}
}
int j = 0;
int2 pos_of_weight;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
CL_DTYPE4 weight_x =
READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y += 3;
CL_DTYPE4 weight_y =
READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y += 3;
CL_DTYPE4 weight_z =
READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y += 3;
CL_DTYPE4 weight_w =
READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 1;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 2;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 3;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 4;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 5;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 6;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 7;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
j = 8;
pos_of_weight.x = i * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3;
weight_x = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.x += dot(input[j], weight_x);
pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3;
weight_y = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.y += dot(input[j], weight_y);
pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3;
weight_z = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.z += dot(input[j], weight_z);
pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3;
weight_w = READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
output.w += dot(input[j], weight_w);
}
} else { // group != 1
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int used_input_channel_num = int used_input_channel_num =
(out_c * 4 + i) / (output_c / group) * filter_channel; (out_c * 4 + i) / (output_c / group) * filter_tensor_c;
for (int f_c = 0; f_c < filter_channel; ++f_c) { for (int filter_tensor_c_idx = 0; filter_tensor_c_idx < filter_tensor_c; ++filter_tensor_c_idx) {
int input_c = used_input_channel_num + f_c; int input_c = used_input_channel_num + filter_tensor_c_idx;
int input_block = input_c / 4; int input_block = input_c / 4;
int2 pos_in = (int2)(input_block * input_width + in_pos_in_one_block.x, int2 pos_in = (int2)(input_block * input_width + in_pos_in_one_block.x,
in_pos_in_one_block.y); in_pos_in_one_block.y);
input[0] = select( input0 = select(
READ_IMG_TYPE(CL_DTYPE_CHAR, READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x - dilation, pos_in.y - dilation)), (int2)(pos_in.x - dilation, pos_in.y - dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x - dilation < 0 || (ushort4)((in_pos_in_one_block.x - dilation < 0 ||
in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.y - dilation < 0 ||
in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.x - dilation >= input_width ||
in_pos_in_one_block.y - dilation >= input_height) in_pos_in_one_block.y - dilation >= input_height)
<< 15)); << 15));
input[1] = input1 =
select(READ_IMG_TYPE(CL_DTYPE_CHAR, select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x, pos_in.y - dilation)), (int2)(pos_in.x, pos_in.y - dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x < 0 || (ushort4)((in_pos_in_one_block.x < 0 ||
in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.y - dilation < 0 ||
in_pos_in_one_block.x >= input_width || in_pos_in_one_block.x >= input_width ||
in_pos_in_one_block.y - dilation >= input_height) in_pos_in_one_block.y - dilation >= input_height)
<< 15)); << 15));
input[2] = select( input2 = select(
READ_IMG_TYPE(CL_DTYPE_CHAR, READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x + dilation, pos_in.y - dilation)), (int2)(pos_in.x + dilation, pos_in.y - dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x + dilation < 0 || (ushort4)((in_pos_in_one_block.x + dilation < 0 ||
in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.y - dilation < 0 ||
in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.x + dilation >= input_width ||
in_pos_in_one_block.y - dilation >= input_height) in_pos_in_one_block.y - dilation >= input_height)
<< 15)); << 15));
input[3] =
input3 =
select(READ_IMG_TYPE(CL_DTYPE_CHAR, select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x - dilation, pos_in.y)), (int2)(pos_in.x - dilation, pos_in.y)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x - dilation < 0 || (ushort4)((in_pos_in_one_block.x - dilation < 0 ||
in_pos_in_one_block.y < 0 || in_pos_in_one_block.y < 0 ||
in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.x - dilation >= input_width ||
in_pos_in_one_block.y >= input_height) in_pos_in_one_block.y >= input_height)
<< 15)); << 15));
input[4] = select(
input4 = select(
READ_IMG_TYPE(CL_DTYPE_CHAR, READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x, pos_in.y)), (int2)(pos_in.x, pos_in.y)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 ||
in_pos_in_one_block.x >= input_width || in_pos_in_one_block.x >= input_width ||
in_pos_in_one_block.y >= input_height) in_pos_in_one_block.y >= input_height)
<< 15)); << 15));
input[5] = input5 =
select(READ_IMG_TYPE(CL_DTYPE_CHAR, select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x + dilation, pos_in.y)), (int2)(pos_in.x + dilation, pos_in.y)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x + dilation < 0 || (ushort4)((in_pos_in_one_block.x + dilation < 0 ||
in_pos_in_one_block.y < 0 || in_pos_in_one_block.y < 0 ||
in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.x + dilation >= input_width ||
in_pos_in_one_block.y >= input_height) in_pos_in_one_block.y >= input_height)
<< 15)); << 15));
input[6] = select( input6 = select(
READ_IMG_TYPE(CL_DTYPE_CHAR, READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x - dilation, pos_in.y + dilation)), (int2)(pos_in.x - dilation, pos_in.y + dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x - dilation < 0 || (ushort4)((in_pos_in_one_block.x - dilation < 0 ||
in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.y + dilation < 0 ||
in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.x - dilation >= input_width ||
in_pos_in_one_block.y + dilation >= input_height) in_pos_in_one_block.y + dilation >= input_height)
<< 15)); << 15));
input[7] = input7 =
select(READ_IMG_TYPE(CL_DTYPE_CHAR, select(READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x, pos_in.y + dilation)), (int2)(pos_in.x, pos_in.y + dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x < 0 || (ushort4)((in_pos_in_one_block.x < 0 ||
in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.y + dilation < 0 ||
in_pos_in_one_block.x >= input_width || in_pos_in_one_block.x >= input_width ||
in_pos_in_one_block.y + dilation >= input_height) in_pos_in_one_block.y + dilation >= input_height)
<< 15)); << 15));
input[8] = select( input8 = select(
READ_IMG_TYPE(CL_DTYPE_CHAR, READ_IMG_TYPE(CL_DTYPE_CHAR,
input_image, input_image,
sampler, sampler,
(int2)(pos_in.x + dilation, pos_in.y + dilation)), (int2)(pos_in.x + dilation, pos_in.y + dilation)),
(CL_DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f), zero_dtype4,
(ushort4)((in_pos_in_one_block.x + dilation < 0 || (ushort4)((in_pos_in_one_block.x + dilation < 0 ||
in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.y + dilation < 0 ||
in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.x + dilation >= input_width ||
...@@ -483,50 +185,75 @@ __kernel void conv2d_3x3(__private const int global_size_dim0, ...@@ -483,50 +185,75 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
CL_DTYPE tmp_out = 0; CL_DTYPE tmp_out = 0;
for (int j = 0; j < 9; j++) { for (int j = 0; j < 9; j++) {
int2 pos_of_weight; int2 pos_of_weight;
pos_of_weight.x = (f_c / 4) * 3 + j % 3; pos_of_weight.x = (filter_tensor_c_idx / 4) * 3 + j % 3;
pos_of_weight.y = out_c * 4 * 3 + i * 3 + j / 3; pos_of_weight.y = out_c * 4 * 3 + i * 3 + j / 3;
CL_DTYPE4 weight = CL_DTYPE4 weight =
READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight); READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, pos_of_weight);
int f_c_offset = f_c % 4; int filter_tensor_c_idx_offset = filter_tensor_c_idx % 4;
CL_DTYPE f_value; CL_DTYPE f_value = 0;
if (f_c_offset == 0) { f_value = (filter_tensor_c_idx_offset == 0) ? weight.x : f_value;
f_value = weight.x; f_value = (filter_tensor_c_idx_offset == 1) ? weight.y : f_value;
} else if (f_c_offset == 1) { f_value = (filter_tensor_c_idx_offset == 2) ? weight.z : f_value;
f_value = weight.y; f_value = (filter_tensor_c_idx_offset == 3) ? weight.w : f_value;
} else if (f_c_offset == 2) {
f_value = weight.z;
} else if (f_c_offset == 3) {
f_value = weight.w;
}
int input_c_offset = input_c % 4; int input_c_offset = input_c % 4;
CL_DTYPE input_value; CL_DTYPE input_value = 0;
if (input_c_offset == 0) { if (j == 0) {
input_value = input[j].x; input_value = (input_c_offset == 0) ? input0.x : input_value;
} else if (input_c_offset == 1) { input_value = (input_c_offset == 1) ? input0.y : input_value;
input_value = input[j].y; input_value = (input_c_offset == 2) ? input0.z : input_value;
} else if (input_c_offset == 2) { input_value = (input_c_offset == 3) ? input0.w : input_value;
input_value = input[j].z; } else if (j == 1) {
} else if (input_c_offset == 3) { input_value = (input_c_offset == 0) ? input1.x : input_value;
input_value = input[j].w; input_value = (input_c_offset == 1) ? input1.y : input_value;
} input_value = (input_c_offset == 2) ? input1.z : input_value;
tmp_out += f_value * input_value; input_value = (input_c_offset == 3) ? input1.w : input_value;
} } else if (j == 2) {
input_value = (input_c_offset == 0) ? input2.x : input_value;
input_value = (input_c_offset == 1) ? input2.y : input_value;
input_value = (input_c_offset == 2) ? input2.z : input_value;
input_value = (input_c_offset == 3) ? input2.w : input_value;
} else if (j == 3) {
input_value = (input_c_offset == 0) ? input3.x : input_value;
input_value = (input_c_offset == 1) ? input3.y : input_value;
input_value = (input_c_offset == 2) ? input3.z : input_value;
input_value = (input_c_offset == 3) ? input3.w : input_value;
} else if (j == 4) {
input_value = (input_c_offset == 0) ? input4.x : input_value;
input_value = (input_c_offset == 1) ? input4.y : input_value;
input_value = (input_c_offset == 2) ? input4.z : input_value;
input_value = (input_c_offset == 3) ? input4.w : input_value;
} else if (j == 5) {
input_value = (input_c_offset == 0) ? input5.x : input_value;
input_value = (input_c_offset == 1) ? input5.y : input_value;
input_value = (input_c_offset == 2) ? input5.z : input_value;
input_value = (input_c_offset == 3) ? input5.w : input_value;
} else if (j == 6) {
input_value = (input_c_offset == 0) ? input6.x : input_value;
input_value = (input_c_offset == 1) ? input6.y : input_value;
input_value = (input_c_offset == 2) ? input6.z : input_value;
input_value = (input_c_offset == 3) ? input6.w : input_value;
} else if (j == 7) {
input_value = (input_c_offset == 0) ? input7.x : input_value;
input_value = (input_c_offset == 1) ? input7.y : input_value;
input_value = (input_c_offset == 2) ? input7.z : input_value;
input_value = (input_c_offset == 3) ? input7.w : input_value;
} else if (j == 8) {
input_value = (input_c_offset == 0) ? input8.x : input_value;
input_value = (input_c_offset == 1) ? input8.y : input_value;
input_value = (input_c_offset == 2) ? input8.z : input_value;
input_value = (input_c_offset == 3) ? input8.w : input_value;
}
if (i == 0) { tmp_out += f_value * input_value;
output.x += tmp_out;
} else if (i == 1) {
output.y += tmp_out;
} else if (i == 2) {
output.z += tmp_out;
} else if (i == 3) {
output.w += tmp_out;
} }
output.x = (i == 0) ? output.x + tmp_out : output.x;
output.y = (i == 1) ? output.y + tmp_out : output.y;
output.z = (i == 2) ? output.z + tmp_out : output.z;
output.w = (i == 3) ? output.w + tmp_out : output.w;
} }
} }
}
output = activation_type4(output); output = activation_type4(output);
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册