From 16811a7c7a56dd27b57d1edf9f7e8a177a16149f Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Wed, 31 Jul 2019 00:01:15 +0800 Subject: [PATCH] support group conv in opencl (#1776) * support group conv in opencl * fix style --- .../kernel/cl/cl-kernel-func/conv_func.cpp | 21 + .../kernel/cl/cl_kernel/conv_kernel.inc.cl | 534 ++++++++++-------- 2 files changed, 333 insertions(+), 222 deletions(-) diff --git a/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp b/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp index 179807c724..b932e5838f 100644 --- a/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp +++ b/src/operators/kernel/cl/cl-kernel-func/conv_func.cpp @@ -55,6 +55,8 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, int input_height = param.Input()->dims()[2]; int output_width = param.Output()->dims()[3]; int output_height = param.Output()->dims()[2]; + int filter_channel = param.Filter()->dims()[1]; + int input_channel = param.Input()->dims()[1]; // DLOG << " c block " << c_block; // DLOG << " w " << w; @@ -205,6 +207,25 @@ void ConvAddBnRelu(framework::CLHelper *cl_helper, status = clSetKernelArg(kernel, index++, sizeof(int), &output_height); CL_CHECK_ERRORS(status); + if (param.Filter()->dims()[2] == 3 && param.Filter()->dims()[3] == 3) { + if (filter_channel != input_channel) { + if (filter_channel != 1) { + status = + clSetKernelArg(kernel, index++, sizeof(int), &filter_channel); + CL_CHECK_ERRORS(status); + int has_group = 1; + status = clSetKernelArg(kernel, index++, sizeof(int), &has_group); + CL_CHECK_ERRORS(status); + } + } else { + status = clSetKernelArg(kernel, index++, sizeof(int), &filter_channel); + CL_CHECK_ERRORS(status); + int has_group = 0; + status = clSetKernelArg(kernel, index++, sizeof(int), &has_group); + CL_CHECK_ERRORS(status); + } + } + status = clEnqueueNDRangeKernel( cl_helper->CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); diff --git a/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl b/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl index 62ba1e5d2d..dbf2151104 100644 --- a/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl +++ b/src/operators/kernel/cl/cl_kernel/conv_kernel.inc.cl @@ -47,7 +47,9 @@ __kernel void conv_3x3(__private const int global_size_dim0, __private const int input_width,/* of one block */ __private const int input_height,/* of one block */ __private const int output_width, - __private const int output_height) { + __private const int output_height, + __private const int filter_channel, + __private const int has_group) { const int out_c = get_global_id(0); const int out_w = get_global_id(1); @@ -87,242 +89,330 @@ __kernel void conv_3x3(__private const int global_size_dim0, half4 output = 0.0f; #endif - half4 input[9]; + half4 input[9]; + if (has_group == 0) { + for (int i = 0; i < input_c; ++i) { + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + input[0] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y - dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); - for (int i = 0; i < input_c; ++i) { - int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); - input[0] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x - dilation, pos_in.y - dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + input[1] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x, pos_in.y - dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); - input[1] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x, pos_in.y - dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + input[2] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y - dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); - input[2] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x + dilation, pos_in.y - dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + input[3] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); - input[3] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x - dilation, pos_in.y)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + input[4] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x, pos_in.y)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15)); - input[4] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x, pos_in.y)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + input[5] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); - input[5] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x + dilation, pos_in.y)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + input[6] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y + dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); - input[6] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x - dilation, pos_in.y + dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); - - input[7] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x, pos_in.y + dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + input[7] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x, pos_in.y + dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); - input[8] = select(read_imageh(input_image, sampler, - (int2)(pos_in.x + dilation, pos_in.y + dilation)), - (half4)(0.0f), - (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + input[8] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y + dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); /* - for (int j = 0; j < 9; ++j) { - int2 pos_of_weight; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - half4 weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - half4 weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - half4 weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - half4 weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - } + for (int j = 0; j < 9; ++j) { + int2 pos_of_weight; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + half4 weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + half4 weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + half4 weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + half4 weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + } */ - int j = 0; - int2 pos_of_weight; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - half4 weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - half4 weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - half4 weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - half4 weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 1; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 2; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 3; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 4; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 5; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 6; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 7; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); - - j = 8; - pos_of_weight.x = i * 3 + j % 3; - pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; - weight_x = read_imageh(filter, sampler, pos_of_weight); - output.x += dot(input[j], weight_x); - - pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; - weight_y = read_imageh(filter, sampler, pos_of_weight); - output.y += dot(input[j], weight_y); - - pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; - weight_z = read_imageh(filter, sampler, pos_of_weight); - output.z += dot(input[j], weight_z); - - pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; - weight_w = read_imageh(filter, sampler, pos_of_weight); - output.w += dot(input[j], weight_w); + int j = 0; + int2 pos_of_weight; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + half4 weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + half4 weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + half4 weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + half4 weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 1; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 2; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 3; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 4; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 5; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 6; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 7; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + + j = 8; + pos_of_weight.x = i * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + 0 * 3 + j / 3; + weight_x = read_imageh(filter, sampler, pos_of_weight); + output.x += dot(input[j], weight_x); + + pos_of_weight.y = out_c * 4 * 3 + 1 * 3 + j / 3; + weight_y = read_imageh(filter, sampler, pos_of_weight); + output.y += dot(input[j], weight_y); + + pos_of_weight.y = out_c * 4 * 3 + 2 * 3 + j / 3; + weight_z = read_imageh(filter, sampler, pos_of_weight); + output.z += dot(input[j], weight_z); + + pos_of_weight.y = out_c * 4 * 3 + 3 * 3 + j / 3; + weight_w = read_imageh(filter, sampler, pos_of_weight); + output.w += dot(input[j], weight_w); + } + } else { + for (int i = 0; i < 4; i++) { + int used_input_channel_num = (out_c * 4 + i) * filter_channel; + for (int f_c = 0; f_c < filter_channel; ++f_c) { + int input_c = used_input_channel_num + f_c; + int input_block = input_c / 4; + int2 pos_in = (int2)(input_block * input_width + in_pos_in_one_block.x, in_pos_in_one_block.y); + input[0] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y - dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + input[1] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x, pos_in.y - dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + input[2] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y - dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y - dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y - dilation >= input_height) << 15)); + input[3] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + input[4] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x, pos_in.y)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + input[5] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y >= input_height) << 15)); + input[6] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x - dilation, pos_in.y + dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x - dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x - dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + input[7] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x, pos_in.y + dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + input[8] = select(read_imageh(input_image, sampler, + (int2)(pos_in.x + dilation, pos_in.y + dilation)), + (half4)(0.0f), + (ushort4)((in_pos_in_one_block.x + dilation < 0 || in_pos_in_one_block.y + dilation < 0 || in_pos_in_one_block.x + dilation >= input_width || in_pos_in_one_block.y + dilation >= input_height) << 15)); + + half tmp_out = 0; + for (int j = 0; j < 9; j++) { + int2 pos_of_weight; + pos_of_weight.x = (f_c / 4) * 3 + j % 3; + pos_of_weight.y = out_c * 4 * 3 + i * 3 + j / 3; + half4 weight = read_imageh(filter, sampler, pos_of_weight); + int f_c_offset = f_c % 4; + half f_value; + if (f_c_offset == 0) { + f_value = weight.x; + } else if (f_c_offset == 1) { + f_value = weight.y; + } else if (f_c_offset == 2) { + f_value = weight.z; + } else if (f_c_offset == 3) { + f_value = weight.w; + } + int input_c_offset = input_c % 4; + half input_value; + if (input_c_offset == 0) { + input_value = input[j].x; + } else if (input_c_offset == 1) { + input_value = input[j].y; + } else if (input_c_offset == 2) { + input_value = input[j].z; + } else if (input_c_offset == 3) { + input_value = input[j].w; + } + tmp_out += f_value * input_value; + } + + if (i == 0) { + output.x += tmp_out; + } else if (i == 1) { + output.y += tmp_out; + } else if (i == 2) { + output.z += tmp_out; + } else if (i == 3) { + output.w += tmp_out; + } + } + } } + #ifdef BATCH_NORM output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0)); #endif -- GitLab