From 564e5ec43120d4e9fa3d50af03eb59c8759671d7 Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Wed, 29 Nov 2017 11:17:16 +0800 Subject: [PATCH] Update 1x1 opencl kernel --- mace/kernels/opencl/cl/common.h | 4 + mace/kernels/opencl/cl/conv_2d_1x1.cl | 309 ++++------------------ mace/kernels/opencl/conv_2d_opencl_1x1.cc | 104 ++------ 3 files changed, 86 insertions(+), 331 deletions(-) diff --git a/mace/kernels/opencl/cl/common.h b/mace/kernels/opencl/cl/common.h index 7c156d8d..13d99c41 100644 --- a/mace/kernels/opencl/cl/common.h +++ b/mace/kernels/opencl/cl/common.h @@ -14,4 +14,8 @@ #define CMD_TYPE_STR(cmd, type) cmd##type #define CMD_TYPE(cmd, type) CMD_TYPE_STR(cmd, type) +#define DATA_TYPE4 VEC_DATA_TYPE(DATA_TYPE, 4) +#define READ_IMAGET CMD_TYPE(read_image, CMD_DATA_TYPE) +#define WRITE_IMAGET CMD_TYPE(write_image, CMD_DATA_TYPE) + #endif // MACE_KERNELS_OPENCL_CL_COMMON_H_ diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index 56f2cedc..e3e8f2c1 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -1,149 +1,14 @@ #include -#define vec_conv_2d_1x1_s1 \ - VEC_DATA_TYPE(DATA_TYPE,4) in0 = vload4(0, input_ptr); \ - VEC_DATA_TYPE(DATA_TYPE,4) in1 = vload4(0, input_ptr + in_pixel); \ - VEC_DATA_TYPE(DATA_TYPE,4) in2 = vload4(0, input_ptr + 2 * in_pixel); \ - VEC_DATA_TYPE(DATA_TYPE,4) in3 = vload4(0, input_ptr + 3 * in_pixel); - - -#define vec_conv_2d_1x1_s2 \ - VEC_DATA_TYPE(DATA_TYPE,4) in00 = vload4(0, input_ptr); \ - VEC_DATA_TYPE(DATA_TYPE,3) in01 = vload3(0, input_ptr + 4); \ - VEC_DATA_TYPE(DATA_TYPE,4) in10 = vload4(0, input_ptr + in_pixel); \ - VEC_DATA_TYPE(DATA_TYPE,3) in11 = vload3(0, input_ptr + in_pixel + 4); \ - VEC_DATA_TYPE(DATA_TYPE,4) in20 = vload4(0, input_ptr + 2 * in_pixel); \ - VEC_DATA_TYPE(DATA_TYPE,3) in21 = vload3(0, input_ptr + 2 * in_pixel + 4);\ - VEC_DATA_TYPE(DATA_TYPE,4) in30 = vload4(0, input_ptr + 3 * in_pixel); \ - VEC_DATA_TYPE(DATA_TYPE,3) in31 = vload3(0, input_ptr + 3 * in_pixel + 4); \ - VEC_DATA_TYPE(DATA_TYPE,4) in0 = (VEC_DATA_TYPE(DATA_TYPE,4))(in00.s02, in01.s02); \ - VEC_DATA_TYPE(DATA_TYPE,4) in1 = (VEC_DATA_TYPE(DATA_TYPE,4))(in10.s02, in11.s02); \ - VEC_DATA_TYPE(DATA_TYPE,4) in2 = (VEC_DATA_TYPE(DATA_TYPE,4))(in20.s02, in21.s02); \ - VEC_DATA_TYPE(DATA_TYPE,4) in3 = (VEC_DATA_TYPE(DATA_TYPE,4))(in30.s02, in31.s02); - - -#define vec_conv_2d_1x1_compute_loop \ - for (int oc = 0; oc < 4; ++oc) { \ - VEC_DATA_TYPE(DATA_TYPE,4) weights = vload4(0, filter_ptr + oc * in_chan_num); \ - VEC_DATA_TYPE(DATA_TYPE,4) out = vload4(0, output_ptr + oc * out_pixel); \ - out += in0 * weights.x; \ - out += in1 * weights.y; \ - out += in2 * weights.z; \ - out += in3 * weights.w; \ - vstore4(out, 0, output_ptr + oc * out_pixel); \ - } - -#define vec_conv_2d_1x1_compute \ - VEC_DATA_TYPE(DATA_TYPE,4) weights = vload4(0, filter_ptr); \ - VEC_DATA_TYPE(DATA_TYPE,4) out = vload4(0, output_ptr); \ - out += in0 * weights.x; \ - out += in1 * weights.y; \ - out += in2 * weights.z; \ - out += in3 * weights.w; \ - vstore4(out, 0, output_ptr); - -// Supported data type: half/float -__kernel void conv_2d_1x1_v2(__global const DATA_TYPE *input, /* n, c, h, w */ - __global const DATA_TYPE *filter, /* o, i, kh, kw */ -#ifdef BIAS - __global const DATA_TYPE *bias, /* o */ -#endif /* defined(BIAS) */ - __global DATA_TYPE *output, /* n, c, h, w */ - __private const int in_chan_num, - __private const int out_chan_num, - __private const int in_height, - __private const int in_width, - __private const int out_height, - __private const int out_width) { - int batch = get_global_id(0); - int out_chan_blk = get_global_id(1); - int out_pixel_blk = get_global_id(2); - - const int in_pixel = in_height * in_width; - const int out_pixel = out_height * out_width; - - const int round_out_width = (out_width + 3) / 4; - const int out_pixel_height = out_pixel_blk / round_out_width; - const int out_pixel_width = out_pixel_blk % round_out_width; - - const int out_chan_begin = out_chan_blk * 4; - const int out_chan_end = min(out_chan_begin + 4, out_chan_num); - const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4; - const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width); - -#ifdef STRIDE_1 - const int stride = 1; -#else - const int stride = 2; -#endif - const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4; - - const int in_offset = batch * in_chan_num * in_pixel; - const int out_offset = batch * out_chan_num * out_pixel; - - const DATA_TYPE *input_base = input + in_offset + in_pixel_begin; - DATA_TYPE *output_base = output + out_offset + out_pixel_begin; - - int out_chan_len = out_chan_end - out_chan_begin; - int pixel_len = out_pixel_end - out_pixel_begin; - - for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { - DATA_TYPE *output_ptr = output_base + out_chan * out_pixel; -#ifdef BIAS - DATA_TYPE bias_value = bias[out_chan]; -#else - DATA_TYPE bias_value = 0; -#endif - for (int p = 0; p < pixel_len; ++p) { - output_ptr[p] = bias_value; - } - } - - int in_chan = 0; - if (pixel_len == 4) { - for (; in_chan + 3 < in_chan_num; in_chan += 4) { - const DATA_TYPE *input_ptr = input_base + in_chan * in_pixel; - int out_chan = out_chan_begin; - for (; out_chan + 3 < out_chan_end; out_chan += 4) { - const DATA_TYPE* filter_ptr = filter + out_chan * in_chan_num + in_chan; - DATA_TYPE *output_ptr = output_base + out_chan * out_pixel; -#ifdef STRIDE_1 - vec_conv_2d_1x1_s1; -#else - vec_conv_2d_1x1_s2; -#endif - vec_conv_2d_1x1_compute_loop; - } - for (; out_chan < out_chan_end; ++out_chan) { - const DATA_TYPE* filter_ptr = filter + out_chan * in_chan_num + in_chan; - DATA_TYPE *output_ptr = output_base + out_chan * out_pixel; -#ifdef STRIDE_1 - vec_conv_2d_1x1_s1; -#else - vec_conv_2d_1x1_s2; -#endif - vec_conv_2d_1x1_compute; - } - } - } - - for (; in_chan < in_chan_num; ++in_chan) { - const DATA_TYPE *input_ptr = input_base + in_chan * in_pixel; - for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { - DATA_TYPE weights = filter[out_chan * in_chan_num + in_chan]; - DATA_TYPE *output_ptr = output_base + out_chan * out_pixel; - - for (int p = 0; p < pixel_len; ++p) { - float in = input_ptr[p*stride]; - output_ptr[p] += in * weights; - } - } - } -} - __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * cin, cout/4 */ +#ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ +#endif +#ifdef FUSED_BATCH_NORM + __read_only image2d_t bn_scale, /* cout%4 * cout/4 */ + __read_only image2d_t bn_offset, /* cout%4 * cout/4 */ +#endif __write_only image2d_t output, __private const int in_ch_blks, __private const int width) { @@ -154,12 +19,14 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - half4 bias_value = read_imageh(bias, sampler, (int2)(out_ch_blk, 0)); - half4 out[4]; - out[0] = (half4)(bias_value.x); - out[1] = (half4)(bias_value.y); - out[2] = (half4)(bias_value.z); - out[3] = (half4)(bias_value.w); + DATA_TYPE4 out[4] = {0}; +#ifdef BIAS + out[0] = + READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0)); + out[1] = out[0]; + out[2] = out[0]; + out[3] = out[0]; +#endif int w[4]; w[0] = out_w_blk; @@ -170,135 +37,75 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] // Unrolling this loop hurt perfmance int in_x_base = 0; for (int in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { - half4 in[4]; - in[0] = read_imageh(input, sampler, (int2)(in_x_base + w[0], out_hb)); + DATA_TYPE4 in[4]; + in[0] = READ_IMAGET(input, sampler, (int2)(in_x_base + w[0], out_hb)); if (w[1] < width) { // conditional load hurt perf, this branching helps sometimes - in[1] = read_imageh(input, sampler, (int2)(in_x_base + w[1], out_hb)); - in[2] = read_imageh(input, sampler, (int2)(in_x_base + w[2], out_hb)); - in[3] = read_imageh(input, sampler, (int2)(in_x_base + w[3], out_hb)); + in[1] = READ_IMAGET(input, sampler, (int2)(in_x_base + w[1], out_hb)); + in[2] = READ_IMAGET(input, sampler, (int2)(in_x_base + w[2], out_hb)); + in[3] = READ_IMAGET(input, sampler, (int2)(in_x_base + w[3], out_hb)); } - // The order matters, load input first then load filter, why? const int filter_x0 = in_ch_blk << 2; - half4 weights[4]; + DATA_TYPE4 weights[4]; #pragma unroll for (int c = 0; c < 4; ++c) { - weights[c] = read_imageh(filter, sampler, (int2)(filter_x0 + c, out_ch_blk)); + weights[c] = READ_IMAGET(filter, sampler, (int2)(filter_x0 + c, out_ch_blk)); } // Will prefetch L2 improve performance? How to pretch image data? // Interleaving load and mul does not improve performance as expected #pragma unroll - for (int c = 0; c < 4; ++c) { - out[c] += in[c].x * weights[0]; - out[c] += in[c].y * weights[1]; - out[c] += in[c].z * weights[2]; - out[c] += in[c].w * weights[3]; + for (int wi = 0; wi < 4; ++wi) { + out[wi] += in[wi].x * weights[0]; + out[wi] += in[wi].y * weights[1]; + out[wi] += in[wi].z * weights[2]; + out[wi] += in[wi].w * weights[3]; } in_x_base += width; } - const int out_x_base = out_ch_blk * width; - write_imageh(output, (int2)(out_x_base + w[0], out_hb), out[0]); - - if (w[1] >= width) return; - write_imageh(output, (int2)(out_x_base + w[1], out_hb), out[1]); - - if (w[2] >= width) return; - write_imageh(output, (int2)(out_x_base + w[2], out_hb), out[2]); - - if (w[3] >= width) return; - write_imageh(output, (int2)(out_x_base + w[3], out_hb), out[3]); -} - -__kernel void conv_2d_1x1_h8(__read_only image2d_t input, /* [c%8 * w * c/8, h * b] */ - __read_only image2d_t filter, /* cout%8 * cin, cout/8 */ - __read_only image2d_t bias, /* cout%8 * cout/8 */ - __write_only image2d_t output, - __private const int in_ch_blks, - __private const int width) { - const int out_ch_blk = get_global_id(0); - const int out_w_blk = get_global_id(1); - const int out_w_blks = get_global_size(1); - const int out_hb = get_global_id(2); - - const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - - float4 bias_value = read_imagef(bias, sampler, (int2)(out_ch_blk, 0)); - half4 bias_value03 = as_half4(bias_value.xy); - half4 bias_value47 = as_half4(bias_value.zw); - half4 out[8]; - out[0] = (half4)(bias_value03.x); - out[1] = (half4)(bias_value03.y); - out[2] = (half4)(bias_value03.z); - out[3] = (half4)(bias_value03.w); - out[4] = (half4)(bias_value47.x); - out[5] = (half4)(bias_value47.y); - out[6] = (half4)(bias_value47.z); - out[7] = (half4)(bias_value47.w); - - int w[4]; - w[0] = out_w_blk; - w[1] = w[0] + out_w_blks; - w[2] = w[1] + out_w_blks; - w[3] = w[2] + out_w_blks; - - // Unrolling this loop hurt perfmance - int in_x_base = 0; - for (int in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { - half4 in[8]; - #pragma unroll - for (int wi = 0; wi < 4; ++wi) { - float4 in_value = read_imagef(input, sampler, (int2)(in_x_base + w[0], out_hb)); - in[wi << 1] = as_half4(in_value.xy); - in[wi << 1 + 1] = as_half4(in_value.zw); - } - - // The order matters, load input first then load filter, why? - const int filter_x0 = in_ch_blk << 2; - half4 weights[8]; - #pragma unroll - for (int wi = 0; wi < 4; ++wi) { - float4 weights_value = read_imagef(filter, sampler, (int2)(filter_x0 + wi, out_ch_blk)); - weights[wi << 1] = as_half4(weights_value.xy); - weights[wi << 1 + 1] = as_half4(weights_value.zw); - } - // Will prefetch L2 improve performance? How to pretch image data? - - // Interleaving load and mul does not improve performance as expected - #pragma unroll - for (int wi = 0; wi < 4; ++wi) { - int idx = wi << 1; - out[idx] += in[idx].x * weights[0]; - out[idx] += in[idx].y * weights[1]; - out[idx] += in[idx].z * weights[2]; - out[idx] += in[idx].w * weights[3]; - - ++idx; - out[idx] += in[idx].x * weights[4]; - out[idx] += in[idx].y * weights[5]; - out[idx] += in[idx].z * weights[6]; - out[idx] += in[idx].w * weights[7]; - } +#ifdef FUSED_BATCH_NORM + // batch norm + DATA_TYPE4 bn_scale_value = + READ_IMAGET(bn_scale, sampler, (int2)(out_ch_blk, 0)); + DATA_TYPE4 scale[4]; + scale[0] = (DATA_TYPE4)(bn_scale_value.x); + scale[1] = (DATA_TYPE4)(bn_scale_value.y); + scale[2] = (DATA_TYPE4)(bn_scale_value.z); + scale[3] = (DATA_TYPE4)(bn_scale_value.w); + DATA_TYPE4 bn_offset_value = + READ_IMAGET(bn_offset, sampler, (int2)(out_ch_blk, 0)); + DATA_TYPE4 offset[4]; + offset[0] = (DATA_TYPE4)(bn_offset_value.x); + offset[1] = (DATA_TYPE4)(bn_offset_value.y); + offset[2] = (DATA_TYPE4)(bn_offset_value.z); + offset[3] = (DATA_TYPE4)(bn_offset_value.w); + + #pragma unroll + for (int wi = 0; wi < 4; ++wi) { + out[wi] = out[wi] * scale[wi] + offset[wi]; + } +#endif - in_x_base += width; +#ifdef FUSED_RELU + #pragma unroll + for (int wi = 0; wi < 4; ++wi) { + // TODO relux + out[wi] = fmax(out[wi], 0); } +#endif const int out_x_base = out_ch_blk * width; - float4 out_value = (float4)(as_float2(out[0]), as_float2(out[1])); - write_imagef(output, (int2)(out_x_base + w[0], out_hb), out_value); + WRITE_IMAGET(output, (int2)(out_x_base + w[3], out_hb), out[0]); if (w[1] >= width) return; - out_value = (float4)(as_float2(out[2]), as_float2(out[3])); - write_imagef(output, (int2)(out_x_base + w[0], out_hb), out_value); + WRITE_IMAGET(output, (int2)(out_x_base + w[1], out_hb), out[1]); if (w[2] >= width) return; - out_value = (float4)(as_float2(out[4]), as_float2(out[5])); - write_imagef(output, (int2)(out_x_base + w[0], out_hb), out_value); + WRITE_IMAGET(output, (int2)(out_x_base + w[3], out_hb), out[2]); if (w[3] >= width) return; - out_value = (float4)(as_float2(out[6]), as_float2(out[7])); - write_imagef(output, (int2)(out_x_base + w[0], out_hb), out_value); + WRITE_IMAGET(output, (int2)(out_x_base + w[3], out_hb), out[3]); } diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index 28f57f48..4198cf78 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -11,75 +11,31 @@ namespace mace { namespace kernels { -void Conv1x1V2(const Tensor *input, - const Tensor *filter, - const Tensor *bias, - const int stride, - Tensor *output) { +void Conv1x1(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + const int stride, + Tensor *output) { const index_t batch = output->dim(0); - const index_t channels = output->dim(1); - const index_t height = output->dim(2); - const index_t width = output->dim(3); - const index_t input_channels = input->dim(1); - - auto runtime = OpenCLRuntime::Get(); - auto program = runtime->program(); - const index_t channel_blocks = (channels + 3) / 4; - const index_t pixel_blocks = (width + 3) / 4 * height; - - // TODO KernelFunctor has an extra clReleaseCommandQueue due to a copy - // TODO check wired clReleaseCommandQueue latency - // The KernelFunctor can cause segment faults in cb_retain_event - std::set built_options; - built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); - built_options.emplace(stride == 1 ? "-DSTRIDE_1" : ""); - built_options.emplace(bias != nullptr ? "-DBIAS" : ""); - auto conv_2d_kernel = runtime->BuildKernel("conv_2d_1x1", "conv_2d_1x1_v2", built_options); - - const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel); - uint32_t idx = 0; - conv_2d_kernel.setArg(idx++, - *(static_cast(input->buffer()))); - conv_2d_kernel.setArg(idx++, - *(static_cast(filter->buffer()))); - if (bias != nullptr) { - conv_2d_kernel.setArg(idx++, - *(static_cast(bias->buffer()))); - } - conv_2d_kernel.setArg(idx++, *(static_cast(output->buffer()))); - conv_2d_kernel.setArg(idx++, static_cast(input_channels)); - conv_2d_kernel.setArg(idx++, static_cast(channels)); - conv_2d_kernel.setArg(idx++, static_cast(input->dim(2))); - conv_2d_kernel.setArg(idx++, static_cast(input->dim(3))); - conv_2d_kernel.setArg(idx++, static_cast(height)); - conv_2d_kernel.setArg(idx++, static_cast(width)); - - auto command_queue = runtime->command_queue(); - cl_int error = command_queue.enqueueNDRangeKernel( - conv_2d_kernel, cl::NullRange, - cl::NDRange(static_cast(batch), static_cast(channel_blocks), - static_cast(pixel_blocks)), - cl::NDRange(1, 2, kwg_size / 2), - NULL, OpenCLRuntime::Get()->GetDefaultEvent()); - MACE_CHECK(error == CL_SUCCESS, error); -} - -void Conv1x1V3(const Tensor *input, - const Tensor *filter, - const Tensor *bias, - const int stride, - Tensor *output) { - const index_t batch = output->dim(0); - const index_t channels = output->dim(1); - const index_t height = output->dim(2); - const index_t width = output->dim(3); - const index_t input_channels = input->dim(1); + const index_t height = output->dim(1); + const index_t width = output->dim(2); + const index_t channels = output->dim(3); + const index_t input_batch = input->dim(0); + const index_t input_height = input->dim(1); + const index_t input_width = input->dim(2); + const index_t input_channels = input->dim(3); const index_t channel_blocks = RoundUpDiv4(channels); + const index_t width_blocks = RoundUpDiv4(width); const index_t input_channel_blocks = RoundUpDiv4(input_channels); + MACE_CHECK(stride == 1); + MACE_CHECK(input_batch == batch); + MACE_CHECK(stride != 1 || (input_height == height && input_width == width)); + std::set built_options; built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); + built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype())); built_options.emplace("-DSTRIDE_1"); built_options.emplace(bias != nullptr ? "-DBIAS" : ""); @@ -103,10 +59,11 @@ void Conv1x1V3(const Tensor *input, cl_int error; error = command_queue.enqueueNDRangeKernel( conv_2d_kernel, cl::NullRange, - cl::NDRange(static_cast(channel_blocks), static_cast(height), + cl::NDRange(static_cast(channel_blocks), + static_cast(width_blocks), static_cast(height * batch)), - cl::NDRange(4, 15, 8), - NULL, OpenCLRuntime::Get()->GetDefaultEvent()); + cl::NDRange(4, 15, 8), // TODO auto tuning + nullptr, OpenCLRuntime::Get()->GetDefaultEvent()); MACE_CHECK(error == CL_SUCCESS, error); } @@ -115,18 +72,7 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *bias, const int *padding, Tensor *output) { - const index_t batch = output->dim(0); - const index_t height = output->dim(2); - const index_t width = output->dim(3); - - const index_t input_batch = input->dim(0); - const index_t input_height = input->dim(2); - const index_t input_width = input->dim(3); - - MACE_CHECK(input_batch == batch && input_height == height && - input_width == width); - - Conv1x1V2(input, filter, bias, 1, output); + Conv1x1(input, filter, bias, 1, output); }; extern void Conv2dOpenclK1x1S2(const Tensor *input, @@ -134,9 +80,7 @@ extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *bias, const int *padding, Tensor *output) { - MACE_CHECK(input->dim(0) == output->dim(0)); - - Conv1x1V2(input, filter, bias, 2, output); + Conv1x1(input, filter, bias, 2, output); }; } // namespace kernels -- GitLab