From 0e078d19b27c6dc7aebdba984943db15c054893b Mon Sep 17 00:00:00 2001 From: yejianwu Date: Mon, 26 Mar 2018 16:22:52 +0800 Subject: [PATCH] compatible with opencl1.1 and 1.2 --- mace/kernels/opencl/activation_opencl.cc | 17 +++- mace/kernels/opencl/addn.cc | 14 ++- mace/kernels/opencl/batch_norm_opencl.cc | 17 +++- mace/kernels/opencl/bias_add_opencl.cc | 21 +++- mace/kernels/opencl/buffer_to_image.cc | 17 +++- mace/kernels/opencl/channel_shuffle.cc | 20 ++-- mace/kernels/opencl/cl/activation.cl | 12 ++- mace/kernels/opencl/cl/addn.cl | 5 +- mace/kernels/opencl/cl/batch_norm.cl | 12 ++- mace/kernels/opencl/cl/bias_add.cl | 12 ++- mace/kernels/opencl/cl/buffer_to_image.cl | 95 ++++++++++++++++--- mace/kernels/opencl/cl/channel_shuffle.cl | 12 ++- mace/kernels/opencl/cl/concat.cl | 24 ++++- mace/kernels/opencl/cl/conv_2d.cl | 13 ++- mace/kernels/opencl/cl/conv_2d_1x1.cl | 13 ++- mace/kernels/opencl/cl/conv_2d_3x3.cl | 13 ++- mace/kernels/opencl/cl/depthwise_conv2d.cl | 22 ++++- mace/kernels/opencl/cl/eltwise.cl | 5 +- mace/kernels/opencl/cl/fully_connected.cl | 21 +++- mace/kernels/opencl/cl/matmul.cl | 6 +- mace/kernels/opencl/cl/pooling.cl | 11 ++- mace/kernels/opencl/cl/resize_bilinear.cl | 14 ++- mace/kernels/opencl/cl/slice.cl | 13 ++- mace/kernels/opencl/cl/softmax.cl | 14 ++- mace/kernels/opencl/cl/space_to_batch.cl | 18 +++- mace/kernels/opencl/cl/winograd_transform.cl | 20 +++- mace/kernels/opencl/concat.cc | 39 +++++--- mace/kernels/opencl/conv_2d_opencl_1x1.cc | 17 +++- mace/kernels/opencl/conv_2d_opencl_3x3.cc | 18 +++- mace/kernels/opencl/conv_2d_opencl_general.cc | 18 +++- mace/kernels/opencl/depthwise_conv_opencl.cc | 18 +++- mace/kernels/opencl/eltwise_opencl.cc | 14 ++- mace/kernels/opencl/fully_connected_opencl.cc | 29 ++++-- mace/kernels/opencl/helper.cc | 32 ++++--- mace/kernels/opencl/matmul.cc | 17 ++-- mace/kernels/opencl/pooling_opencl.cc | 44 ++++++--- mace/kernels/opencl/resize_bilinear_opencl.cc | 17 +++- mace/kernels/opencl/slice.cc | 11 ++- mace/kernels/opencl/softmax_opencl.cc | 18 ++-- mace/kernels/opencl/space_to_batch_opencl.cc | 19 ++-- mace/kernels/opencl/winograd_transform.cc | 32 +++++-- 41 files changed, 614 insertions(+), 190 deletions(-) diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc index 9792cae5..dfe703dd 100644 --- a/mace/kernels/opencl/activation_opencl.cc +++ b/mace/kernels/opencl/activation_opencl.cc @@ -24,8 +24,9 @@ void ActivationFunctor::operator()(const Tensor *input, const index_t channel_blocks = RoundUpDiv4(channels); + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("activation"); @@ -60,6 +61,10 @@ void ActivationFunctor::operator()(const Tensor *input, kernel_ = runtime->BuildKernel("activation", kernel_name, built_options); } + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), + static_cast(height * batch)}; + if (!IsVecEqual(input_shape_, input->shape())) { int idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); @@ -69,14 +74,16 @@ void ActivationFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, static_cast(relux_max_limit_)); kernel_.setArg(idx++, *(output->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::string tuning_key = Concat(tuning_key_prefix_, output->dim(0), output->dim(1), output->dim(2), output->dim(3)); diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc index e7869bb2..94538fc2 100644 --- a/mace/kernels/opencl/addn.cc +++ b/mace/kernels/opencl/addn.cc @@ -24,6 +24,8 @@ void AddNFunctor::operator()( const index_t width = input_tensors[0]->dim(2); const index_t channels = input_tensors[0]->dim(3); + auto runtime = OpenCLRuntime::Global(); + for (int i = 1; i < size; ++i) { MACE_CHECK_NOTNULL(input_tensors[i]); MACE_CHECK(batch == input_tensors[i]->dim(0)); @@ -36,7 +38,6 @@ void AddNFunctor::operator()( if (input_tensors.size() > 4) { MACE_NOT_IMPLEMENTED; } - auto runtime = OpenCLRuntime::Global(); std::set built_options; auto dt = DataTypeToEnum::value; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("addn"); @@ -53,6 +54,9 @@ void AddNFunctor::operator()( const index_t width_pixels = channel_blocks * width; const index_t batch_height_pixels = batch * height; + const uint32_t gws[2] = {static_cast(width_pixels), + static_cast(batch_height_pixels)}; + if (!IsVecEqual(input_shape_, input_tensors[0]->shape())) { std::vector output_image_shape; CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, @@ -64,13 +68,15 @@ void AddNFunctor::operator()( kernel_.setArg(idx++, *(input->opencl_image())); } kernel_.setArg(idx++, *(output_tensor->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); input_shape_ = input_tensors[0]->shape(); } - const uint32_t gws[2] = {static_cast(width_pixels), - static_cast(batch_height_pixels)}; - const std::vector lws = {64, 16, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {kwg_size / 16, 16, 1}; std::stringstream ss; ss << "addn_opencl_kernel_" << output_shape[0] << "_" << output_shape[1] << "_" << output_shape[2] << "_" << output_shape[3]; diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index d9dfb825..d79b5c18 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -30,8 +30,13 @@ void BatchNormFunctor::operator()(const Tensor *input, const index_t channel_blocks = RoundUpDiv4(channels); + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), + static_cast(height * batch)}; + + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; auto dt = DataTypeToEnum::value; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("batch_norm"); @@ -74,14 +79,16 @@ void BatchNormFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, relux_max_limit_); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::string tuning_key = Concat("batch_norm_opencl_kernel_", activation_, output->dim(0), output->dim(1), output->dim(2), output->dim(3), folded_constant_); diff --git a/mace/kernels/opencl/bias_add_opencl.cc b/mace/kernels/opencl/bias_add_opencl.cc index 3d4c4ec5..69327995 100644 --- a/mace/kernels/opencl/bias_add_opencl.cc +++ b/mace/kernels/opencl/bias_add_opencl.cc @@ -23,6 +23,10 @@ void BiasAddFunctor::operator()(const Tensor *input, const index_t channel_blocks = RoundUpDiv4(channels); + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), + static_cast(height * batch)}; + auto runtime = OpenCLRuntime::Global(); if (kernel_.get() == nullptr) { std::set built_options; @@ -38,17 +42,24 @@ void BiasAddFunctor::operator()(const Tensor *input, kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(bias->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8}; + + std::vector roundup_gws(lws.size()); + for (size_t i = 0; i < lws.size(); ++i) { + roundup_gws[i] = RoundUp(gws[i], lws[i]); + } cl::Event event; cl_int error = runtime->command_queue().enqueueNDRangeKernel( - kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), + kernel_, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]), cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS); if (future != nullptr) { diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index 6d8f3ef1..9fee7a95 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -26,7 +26,8 @@ void BufferToImageFunctor::operator()( buffer->Resize(image->shape()); } - size_t gws[2] = {image_shape[0], image_shape[1]}; + uint32_t gws[2] = {static_cast(image_shape[0]), + static_cast(image_shape[1])}; std::string kernel_name; switch (type) { case CONV2D_FILTER: @@ -98,10 +99,20 @@ void BufferToImageFunctor::operator()( b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } b2f_kernel.setArg(idx++, *(image->opencl_image())); - const std::vector lws = {16, 64}; + b2f_kernel.setArg(idx++, gws[0]); + b2f_kernel.setArg(idx++, gws[1]); + + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(b2f_kernel)); + const std::vector lws = {16, kwg_size / 16}; + std::vector roundup_gws(lws.size()); + for (size_t i = 0; i < lws.size(); ++i) { + roundup_gws[i] = RoundUp(gws[i], lws[i]); + } + cl::Event event; cl_int error = runtime->command_queue().enqueueNDRangeKernel( - b2f_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1]), + b2f_kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1]), cl::NDRange(lws[0], lws[1]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; diff --git a/mace/kernels/opencl/channel_shuffle.cc b/mace/kernels/opencl/channel_shuffle.cc index 78d855e2..34bc5784 100644 --- a/mace/kernels/opencl/channel_shuffle.cc +++ b/mace/kernels/opencl/channel_shuffle.cc @@ -30,9 +30,13 @@ void ChannelShuffleFunctor::operator()( "groups must be multiple of 4"); const index_t group_channel_blocks = RoundUpDiv4(channels_per_group); - if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); + const uint32_t gws[3] = {static_cast(group_channel_blocks), + static_cast(width), + static_cast(height * batch)}; + + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("channel_shuffle"); built_options.emplace("-Dchannel_shuffle=" + kernel_name); @@ -42,19 +46,23 @@ void ChannelShuffleFunctor::operator()( kernel_ = runtime->BuildKernel("channel_shuffle", kernel_name, built_options); } + if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, groups_); kernel_.setArg(idx++, static_cast(channels_per_group)); kernel_.setArg(idx++, *(output->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); } - const uint32_t gws[3] = {static_cast(group_channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "channel_shuffle_opencl_kernel_" << output->dim(0) << "_" diff --git a/mace/kernels/opencl/cl/activation.cl b/mace/kernels/opencl/cl/activation.cl index bee0b0e3..23e6d60e 100644 --- a/mace/kernels/opencl/cl/activation.cl +++ b/mace/kernels/opencl/cl/activation.cl @@ -5,11 +5,19 @@ __kernel void activation(__read_only image2d_t input, __read_only image2d_t alpha, #endif __private const float relux_max_limit, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); - const int width = get_global_size(1); + if (ch_blk >= global_size_dim0 || w >= global_size_dim1 + || hb >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; const int pos = mad24(ch_blk, width, w); DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); diff --git a/mace/kernels/opencl/cl/addn.cl b/mace/kernels/opencl/cl/addn.cl index 9504d12a..4279fc23 100644 --- a/mace/kernels/opencl/cl/addn.cl +++ b/mace/kernels/opencl/cl/addn.cl @@ -8,9 +8,12 @@ __kernel void addn(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ #if INPUT_NUM > 3 __read_only image2d_t input3, #endif - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { const int w = get_global_id(0); const int hb = get_global_id(1); + if (w >= global_size_dim0 || hb >= global_size_dim1) return; DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb)); DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb)); diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index 773b59c4..5899fb00 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -9,11 +9,19 @@ __kernel void batch_norm(__read_only image2d_t input, __private const float epsilon, #endif __write_only image2d_t output, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); - const int width = get_global_size(1); + if (ch_blk >= global_size_dim0 || w >= global_size_dim1 + || hb >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; #ifdef FOLDED_CONSTANT DATA_TYPE4 bn_scale = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0)); diff --git a/mace/kernels/opencl/cl/bias_add.cl b/mace/kernels/opencl/cl/bias_add.cl index f5180a3c..d139652b 100644 --- a/mace/kernels/opencl/cl/bias_add.cl +++ b/mace/kernels/opencl/cl/bias_add.cl @@ -2,11 +2,19 @@ // Supported data types: half/float __kernel void bias_add(__read_only image2d_t input, __read_only image2d_t bias, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); - const int width = get_global_size(1); + if (ch_blk >= global_size_dim0 || w >= global_size_dim1 + || hb >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; const int pos = mad24(ch_blk, width, w); DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); diff --git a/mace/kernels/opencl/cl/buffer_to_image.cl b/mace/kernels/opencl/cl/buffer_to_image.cl index 781d21e3..faf1f091 100644 --- a/mace/kernels/opencl/cl/buffer_to_image.cl +++ b/mace/kernels/opencl/cl/buffer_to_image.cl @@ -5,9 +5,15 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, o __private const int filter_w, __private const int out_channel, __private const int in_channel, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int out_channel_idx = h * 4; const int rounded_in_channel = ((in_channel + 3) / 4) * 4; const int hw_idx = w / rounded_in_channel; @@ -45,9 +51,15 @@ __kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic __private const int filter_w, __private const int out_channel, __private const int in_channel, - __read_only image2d_t input) { + __read_only image2d_t input, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int out_channel_idx = h * 4; const int rounded_in_channel = ((in_channel + 3) / 4) * 4; const int hw_idx = w / rounded_in_channel; @@ -84,9 +96,14 @@ __kernel void dw_filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w __private const int filter_w, __private const int in_channel, __private const int multiplier, - __write_only image2d_t output) { /* ic%4 * kh * kw * m, ic/4 */ + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { /* ic%4 * kh * kw * m, ic/4 */ const int w = get_global_id(0); const int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } DATA_TYPE4 values = 0; if (multiplier == 1) { @@ -134,9 +151,15 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ __private const int height, __private const int width, __private const int channels, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int batch_idx = h / height; const int height_idx = h % height; const int width_idx = w % width; @@ -166,9 +189,15 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ __private const int height, __private const int width, __private const int channels, - __read_only image2d_t input) { + __read_only image2d_t input, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int batch_idx = h / height; const int height_idx = h % height; const int width_idx = w % width; @@ -196,9 +225,14 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ __private const int input_offset, __private const int count, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } const int offset = input_offset + w * 4; const int size = count - w * 4; @@ -223,9 +257,14 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ __kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ __private const int count, - __read_only image2d_t input) { + __read_only image2d_t input, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } const int offset = w * 4; int2 coord = (int2)(w, h); @@ -251,9 +290,15 @@ __kernel void in_out_height_buffer_to_image(__global const DATA_TYPE *input, //n __private const int height, __private const int width, __private const int channels, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int wc = width * channels; const int height_blks = (height + 3) / 4; const int batch_idx = h / height_blks; @@ -284,9 +329,15 @@ __kernel void in_out_height_image_to_buffer(__global DATA_TYPE *output, //nhwc __private const int height, __private const int width, __private const int channels, - __read_only image2d_t input) { + __read_only image2d_t input, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int height_blks = (height + 3) / 4; const int batch_idx = h / height_blks; const int height_idx = (h % height_blks) << 2; @@ -315,9 +366,15 @@ __kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* n __private const int height, __private const int width, __private const int channels, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int width_blks = (width + 3) / 4; const int batch_idx = h / height; const int height_idx = h % height; @@ -349,10 +406,16 @@ __kernel void winograd_filter_buffer_to_image(__global const DATA_TYPE *input, / __private const int in_channels, __private const int height, __private const int width, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { int w = get_global_id(0); int h = get_global_id(1); - const int out_channels = get_global_size(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + + const int out_channels = global_size_dim1; const int out_channel_idx = h; const int in_channel_idx = w << 2; const int offset = input_offset + (out_channel_idx * in_channels + in_channel_idx) * height * width; @@ -429,9 +492,15 @@ __kernel void winograd_filter_image_to_buffer(__global DATA_TYPE *output, //Oc, __private const int height, __private const int width, __private const int channel, - __read_only image2d_t input) { + __read_only image2d_t input, + __private const int global_size_dim0, + __private const int global_size_dim1) { const int w = get_global_id(0); const int h = get_global_id(1); + if (w >= global_size_dim0 || h >= global_size_dim1) { + return; + } + const int width_idx = w << 2; const int size = width - width_idx; int offset = h * width + width_idx; diff --git a/mace/kernels/opencl/cl/channel_shuffle.cl b/mace/kernels/opencl/cl/channel_shuffle.cl index 2a193a23..6437ee7f 100644 --- a/mace/kernels/opencl/cl/channel_shuffle.cl +++ b/mace/kernels/opencl/cl/channel_shuffle.cl @@ -4,11 +4,19 @@ __kernel void channel_shuffle(__read_only image2d_t input, __private const int groups, __private const int channels_per_group, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int group_chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); - const int width = get_global_size(1); const int hb_idx = get_global_id(2); + if (group_chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1 + || hb_idx >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; const int group_blks = groups / 4; const int groups_blks_width = group_blks * width; const int channels_per_group_blks = channels_per_group / 4; diff --git a/mace/kernels/opencl/cl/concat.cl b/mace/kernels/opencl/cl/concat.cl index af13422d..ac74f0f2 100644 --- a/mace/kernels/opencl/cl/concat.cl +++ b/mace/kernels/opencl/cl/concat.cl @@ -25,11 +25,19 @@ DATA_TYPE4 stitch_vector(DATA_TYPE4 left, __kernel void concat_channel(__read_only image2d_t input0, __read_only image2d_t input1, __private const int input0_chan, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); - const int width = get_global_size(1); const int hb_idx = get_global_id(2); + if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1 + || hb_idx >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; const int input0_chan_blk = (input0_chan + 3) >> 2; DATA_TYPE4 data = 0; @@ -74,11 +82,19 @@ __kernel void concat_channel(__read_only image2d_t input0, // Required: All input channels are divisible by 4 __kernel void concat_channel_multi(__read_only image2d_t input, __private const int chan_blk_offset, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); - const int width = get_global_size(1); const int hb_idx = get_global_id(2); + if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1 + || hb_idx >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; DATA_TYPE4 data = 0; data = READ_IMAGET(input, SAMPLER, diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index 8ed3073f..75be47f1 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -18,11 +18,20 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __private const int padding_top, __private const int padding_left, __private const int dilation_h, - __private const int dilation_w) { + __private const int dilation_w, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); - const int out_w_blks = get_global_size(1); const int out_hb = get_global_id(2); + + if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1 + || out_hb >= global_size_dim2) { + return; + } + + const int out_w_blks = global_size_dim1; const int rounded_in_ch = in_ch_blks << 2; #ifdef BIAS diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index b695165e..a9e4f95f 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -12,12 +12,21 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] __private const int in_ch_blks, __private const int height, __private const int width, - __private const int stride) { + __private const int stride, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); - const int out_w_blks = get_global_size(1); const int out_hb = get_global_id(2); + if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1 + || out_hb >= global_size_dim2) { + return; + } + + const int out_w_blks = global_size_dim1; + #ifdef BIAS DATA_TYPE4 out0 = READ_IMAGET(bias, SAMPLER, (int2)(out_ch_blk, 0)); DATA_TYPE4 out1 = out0; diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index fad561aa..b2d8eaa4 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -16,11 +16,20 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] __private const int padding_top, __private const int padding_left, __private const int dilation_h, - __private const int dilation_w) { + __private const int dilation_w, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); - const int out_w_blks = get_global_size(1); const int out_hb = get_global_id(2); + + if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1 + || out_hb >= global_size_dim2) { + return; + } + + const int out_w_blks = global_size_dim1; const int rounded_in_ch = in_ch_blks << 2; #ifdef BIAS diff --git a/mace/kernels/opencl/cl/depthwise_conv2d.cl b/mace/kernels/opencl/cl/depthwise_conv2d.cl index 792c0934..28125a8d 100644 --- a/mace/kernels/opencl/cl/depthwise_conv2d.cl +++ b/mace/kernels/opencl/cl/depthwise_conv2d.cl @@ -18,11 +18,19 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h __private const short padding_top, __private const short padding_left, __private const short dilation_h, - __private const short dilation_w) { + __private const short dilation_w, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const short out_ch_blk = get_global_id(0); const short out_w_blk = get_global_id(1); - const short out_w_blks = get_global_size(1); const short out_hb = get_global_id(2); + if (out_ch_blk >= global_size_dim0 || out_w_blk >= global_size_dim1 + || out_hb >= global_size_dim2) { + return; + } + + const short out_w_blks = global_size_dim1; const short rounded_in_ch = in_ch_blks << 2; const short in_ch_blk = out_ch_blk; // multiplier = 1 @@ -141,10 +149,18 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4 __private const short filter_height, __private const short filter_width, __private const short padding_top, - __private const short padding_left) { + __private const short padding_left, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const short out_ch_blk = get_global_id(0); const short out_w_blk = get_global_id(1) << 2; const short out_hb = get_global_id(2); + if (out_ch_blk >= global_size_dim0 || get_global_id(1) >= global_size_dim1 + || out_hb >= global_size_dim2) { + return; + } + const short rounded_in_ch = in_ch_blks << 2; const short in_ch_blk = out_ch_blk; // multiplier = 1 diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index 735bc96e..edfb777d 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -6,9 +6,12 @@ __kernel void eltwise(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ __private const float coeff0, __private const float coeff1, #endif - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1) { const int w = get_global_id(0); const int hb = get_global_id(1); + if (w >= global_size_dim0 || hb >= global_size_dim1) return; DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb)); DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb)); diff --git a/mace/kernels/opencl/cl/fully_connected.cl b/mace/kernels/opencl/cl/fully_connected.cl index 057a66a4..90d84c11 100644 --- a/mace/kernels/opencl/cl/fully_connected.cl +++ b/mace/kernels/opencl/cl/fully_connected.cl @@ -10,9 +10,15 @@ __kernel void fully_connected(__read_only image2d_t input, __private const int input_height, __private const int input_width, __private const int input_channel, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const int global_size_dim0, + __private const int global_size_dim1) { const int batch_idx = get_global_id(0); const int out_blk_idx = get_global_id(1); + if (batch_idx >= global_size_dim0 || out_blk_idx >= global_size_dim1) { + return; + } + const int input_chan_blk = (input_channel + 3) >> 2; float4 input_value; @@ -68,11 +74,20 @@ __kernel void fully_connected_width(__read_only image2d_t input, __private const int input_width, __private const int in_chan_blks, __private const int out_blks, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int inter_out_idx = get_global_id(0); const int width_blk_idx = get_global_id(1); - const int width_blk_count = get_global_size(1); const int batch_out_blk_idx = get_global_id(2); + if (inter_out_idx >= global_size_dim0 || width_blk_idx >= global_size_dim1 + || batch_out_blk_idx >= global_size_dim2) { + return; + } + + const int width_blk_count = global_size_dim1; + const int batch_idx = batch_out_blk_idx / out_blks; const int out_blk_idx = batch_out_blk_idx % out_blks; diff --git a/mace/kernels/opencl/cl/matmul.cl b/mace/kernels/opencl/cl/matmul.cl index cb71f21d..f0c2ee0e 100644 --- a/mace/kernels/opencl/cl/matmul.cl +++ b/mace/kernels/opencl/cl/matmul.cl @@ -8,9 +8,13 @@ __kernel void matmul(__read_only image2d_t A, __private const int N, __private const int K, __private const int height_blocks, - __private const int k_blocks) { + __private const int k_blocks, + __private const int global_size_dim0, + __private const int global_size_dim1) { const int gx = get_global_id(0) << 2; const int hb = get_global_id(1); + if (get_global_id(0) >= global_size_dim0 || hb >= global_size_dim1) return; + const int batch = hb / height_blocks; const int ty = (hb % height_blocks); const int gy = mad24(batch, height_blocks, ty); diff --git a/mace/kernels/opencl/cl/pooling.cl b/mace/kernels/opencl/cl/pooling.cl index f2298a93..dad48824 100644 --- a/mace/kernels/opencl/cl/pooling.cl +++ b/mace/kernels/opencl/cl/pooling.cl @@ -27,12 +27,19 @@ __kernel void pooling(__read_only image2d_t input, __private const int pad_left, __private const int stride, __private const int pooling_size, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int out_chan_idx = get_global_id(0); const int out_width_idx = get_global_id(1); - const int out_width = get_global_size(1); const int out_hb_idx = get_global_id(2); + if (out_chan_idx >= global_size_dim0 || out_width_idx >= global_size_dim1 + || out_hb_idx >= global_size_dim2) { + return; + } + const int out_width = global_size_dim1; const int batch_idx = mul24((out_hb_idx / out_height), in_height); const int in_height_start = mul24((out_hb_idx % out_height), stride) - pad_top; const int in_width_start = mul24(out_width_idx, stride) - pad_left; diff --git a/mace/kernels/opencl/cl/resize_bilinear.cl b/mace/kernels/opencl/cl/resize_bilinear.cl index e0b4b83d..b3778cb2 100644 --- a/mace/kernels/opencl/cl/resize_bilinear.cl +++ b/mace/kernels/opencl/cl/resize_bilinear.cl @@ -6,12 +6,20 @@ __kernel void resize_bilinear_nocache(__read_only image2d_t input, /* [c%4 * w * __private const float width_scale, __private const int in_height, __private const int in_width, - __private const int out_height) { + __private const int out_height, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int ch_blk = get_global_id(0); - const int ch_blks = get_global_size(0); const int w = get_global_id(1); - const int out_width = get_global_size(1); const int hb = get_global_id(2); + if (ch_blk >= global_size_dim0 || w >= global_size_dim1 + || hb >= global_size_dim2) { + return; + } + const int ch_blks = global_size_dim0; + const int out_width = global_size_dim1; + const int b = hb / out_height; const int h = hb % out_height; diff --git a/mace/kernels/opencl/cl/slice.cl b/mace/kernels/opencl/cl/slice.cl index d8d45bcb..a626c0de 100644 --- a/mace/kernels/opencl/cl/slice.cl +++ b/mace/kernels/opencl/cl/slice.cl @@ -2,11 +2,20 @@ __kernel void slice(__read_only image2d_t input, __private const int chan_blk_offset, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); - const int width = get_global_size(1); const int hb_idx = get_global_id(2); + if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1 + || hb_idx >= global_size_dim2) { + return; + } + + const int width = global_size_dim1; + DATA_TYPE4 data = READ_IMAGET(input, SAMPLER, (int2)(mad24(chan_blk_idx + chan_blk_offset, width, width_idx), hb_idx)); diff --git a/mace/kernels/opencl/cl/softmax.cl b/mace/kernels/opencl/cl/softmax.cl index 6830b508..e7027394 100644 --- a/mace/kernels/opencl/cl/softmax.cl +++ b/mace/kernels/opencl/cl/softmax.cl @@ -3,12 +3,20 @@ __kernel void softmax(__read_only image2d_t input, __private const int channels, __private const int remain_channels, - __write_only image2d_t output) { + __write_only image2d_t output, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); const int hb_idx = get_global_id(2); - const int chan_blks = get_global_size(0) - 1; - const int width = get_global_size(1); + if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1 + || hb_idx >= global_size_dim2) { + return; + } + + const int chan_blks = global_size_dim0 - 1; + const int width = global_size_dim1; int pos = width_idx; DATA_TYPE max_value = -FLT_MAX; diff --git a/mace/kernels/opencl/cl/space_to_batch.cl b/mace/kernels/opencl/cl/space_to_batch.cl index 9ad63509..e36313fe 100644 --- a/mace/kernels/opencl/cl/space_to_batch.cl +++ b/mace/kernels/opencl/cl/space_to_batch.cl @@ -9,10 +9,17 @@ __kernel void space_to_batch(__read_only image2d_t space_data, __private const int space_height, __private const int space_width, __private const int batch_height, - __private const int batch_width) { + __private const int batch_width, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int chan_idx = get_global_id(0); const int batch_w_idx = get_global_id(1); const int batch_hb_idx = get_global_id(2); + if (chan_idx >= global_size_dim0 || batch_w_idx >= global_size_dim1 + || batch_hb_idx >= global_size_dim2) { + return; + } const int batch_b_idx = batch_hb_idx / batch_height; const int batch_h_idx = batch_hb_idx % batch_height; @@ -48,10 +55,17 @@ __kernel void batch_to_space(__read_only image2d_t batch_data, __private const int space_height, __private const int space_width, __private const int batch_height, - __private const int batch_width) { + __private const int batch_width, + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2) { const int chan_idx = get_global_id(0); const int batch_w_idx = get_global_id(1); const int batch_hb_idx = get_global_id(2); + if (chan_idx >= global_size_dim0 || batch_w_idx >= global_size_dim1 + || batch_hb_idx >= global_size_dim2) { + return; + } const int batch_b_idx = batch_hb_idx / batch_height; const int batch_h_idx = batch_hb_idx % batch_height; diff --git a/mace/kernels/opencl/cl/winograd_transform.cl b/mace/kernels/opencl/cl/winograd_transform.cl index cbcd3b19..3acfc902 100644 --- a/mace/kernels/opencl/cl/winograd_transform.cl +++ b/mace/kernels/opencl/cl/winograd_transform.cl @@ -8,10 +8,16 @@ __kernel void winograd_transform_2x2(__read_only image2d_t input, __private const int round_hw, __private const int round_w, __private const int padding_top, - __private const int padding_left) { + __private const int padding_left, + __private const int global_size_dim0, + __private const int global_size_dim1) { int out_width_idx = get_global_id(0); int chan_blk_idx = get_global_id(1); - const int chan_blk_size = get_global_size(1); + if (out_width_idx >= global_size_dim0 || chan_blk_idx >= global_size_dim1) { + return; + } + + const int chan_blk_size = global_size_dim1; const int batch_idx = out_width_idx / round_hw; const int t_idx = out_width_idx % round_hw; @@ -115,10 +121,16 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, __private const int out_width, __private const int round_hw, __private const int round_w, - __private const float relux_max_limit) { + __private const float relux_max_limit, + __private const int global_size_dim0, + __private const int global_size_dim1) { const int width_idx = get_global_id(0); const int height_idx = get_global_id(1); - const int out_channel = get_global_size(1); + if (width_idx >= global_size_dim0 || height_idx >= global_size_dim1) { + return; + } + + const int out_channel = global_size_dim1; int width = width_idx; int height = height_idx; diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc index da8671db..ccb5b6c2 100644 --- a/mace/kernels/opencl/concat.cc +++ b/mace/kernels/opencl/concat.cc @@ -24,9 +24,14 @@ static void Concat2(cl::Kernel *kernel, const index_t channel = output->dim(3); const int channel_blk = RoundUpDiv4(channel); + const uint32_t gws[3] = { + static_cast(channel_blk), static_cast(width), + static_cast(batch * height), + }; + + auto runtime = OpenCLRuntime::Global(); if (kernel->get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("concat_channel"); built_options.emplace("-Dconcat_channel=" + kernel_name); @@ -51,14 +56,16 @@ static void Concat2(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(input0->dim(3))); kernel->setArg(idx++, *(static_cast(output->opencl_image()))); + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + *prev_input_shape = input0->shape(); } - const uint32_t gws[3] = { - static_cast(channel_blk), static_cast(width), - static_cast(batch * height), - }; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "concat_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); @@ -75,8 +82,8 @@ static void ConcatN(cl::Kernel *kernel, const index_t width = output->dim(2); const index_t channel = output->dim(3); + auto runtime = OpenCLRuntime::Global(); if (kernel->get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("concat_channel_multi"); built_options.emplace("-Dconcat_channel_multi=" + kernel_name); @@ -89,18 +96,24 @@ static void ConcatN(cl::Kernel *kernel, index_t chan_blk_offset = 0; for (int i = 0; i < inputs_count; ++i) { const Tensor *input = input_list[i]; + index_t input_channel_blk = input->dim(3) / 4; + const uint32_t gws[3] = { + static_cast(input_channel_blk), static_cast(width), + static_cast(batch * height), + }; + uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, static_cast(chan_blk_offset)); kernel->setArg(idx++, *(output->opencl_image())); + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); - index_t input_channel_blk = input->dim(3) / 4; chan_blk_offset += input_channel_blk; - const uint32_t gws[3] = { - static_cast(input_channel_blk), static_cast(width), - static_cast(batch * height), - }; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "concat_n_opencl_kernel_" << input_channel_blk << "_" << width << "_" << batch * height; diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index 62f8b09a..4bfa9ac7 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -36,6 +36,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, const index_t width_blocks = RoundUpDiv4(width); const index_t input_channel_blocks = RoundUpDiv4(input_channels); + auto runtime = OpenCLRuntime::Global(); if (kernel->get() == nullptr) { MACE_CHECK(input_batch == batch); @@ -66,9 +67,13 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, LOG(FATAL) << "Unknown activation type: " << activation; } - auto runtime = OpenCLRuntime::Global(); *kernel = runtime->BuildKernel("conv_2d_1x1", kernel_name, built_options); } + + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width_blocks), + static_cast(height * batch)}; + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); @@ -85,14 +90,16 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(height)); kernel->setArg(idx++, static_cast(width)); kernel->setArg(idx++, stride); + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width_blocks), - static_cast(height * batch)}; - const std::vector lws = {8, 15, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::string tuning_key = Concat("conv2d_1x1_opencl_kernel_", activation, output->dim(0), output->dim(1), output->dim(2), output->dim(3)); diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index ba047cdf..97db8ab9 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -35,6 +35,8 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, const index_t input_channel_blocks = RoundUpDiv4(input_channels); const index_t width_blocks = RoundUpDiv(width); + auto runtime = OpenCLRuntime::Global(); + if (kernel->get() == nullptr) { std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d_3x3"); @@ -61,9 +63,13 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, LOG(FATAL) << "Unknown activation type: " << activation; } - auto runtime = OpenCLRuntime::Global(); *kernel = runtime->BuildKernel("conv_2d_3x3", kernel_name, built_options); } + + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width_blocks), + static_cast(height * batch)}; + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); @@ -83,14 +89,16 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[1]); + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width_blocks), - static_cast(height * batch)}; - const std::vector lws = {4, 15, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); + const std::vector lws = {4, kwg_size / 32, 8, 1}; std::string tuning_key = Concat("conv2d_3x3_opencl_kernel_", activation, output->dim(0), output->dim(1), output->dim(2), output->dim(3)); diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc b/mace/kernels/opencl/conv_2d_opencl_general.cc index fd48605f..4f1b67f6 100644 --- a/mace/kernels/opencl/conv_2d_opencl_general.cc +++ b/mace/kernels/opencl/conv_2d_opencl_general.cc @@ -35,6 +35,8 @@ extern void Conv2dOpencl(cl::Kernel *kernel, const index_t input_channel_blocks = RoundUpDiv4(input_channels); const index_t width_blocks = RoundUpDiv4(width); + auto runtime = OpenCLRuntime::Global(); + if (kernel->get() == nullptr) { std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv_2d"); @@ -61,9 +63,13 @@ extern void Conv2dOpencl(cl::Kernel *kernel, LOG(FATAL) << "Unknown activation type: " << activation; } - auto runtime = OpenCLRuntime::Global(); *kernel = runtime->BuildKernel("conv_2d", kernel_name, built_options); } + + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width_blocks), + static_cast(height * batch)}; + if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); @@ -85,14 +91,16 @@ extern void Conv2dOpencl(cl::Kernel *kernel, kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[1]); + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width_blocks), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::string tuning_key = Concat("conv2d_general_opencl_kernel_", activation, output->dim(0), output->dim(1), output->dim(2), output->dim(3)); diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index ecb109d1..18b53853 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -35,8 +35,14 @@ void DepthwiseConv2d(cl::Kernel *kernel, const index_t channel_blocks = RoundUpDiv4(channels); const index_t input_channel_blocks = RoundUpDiv4(input_channels); const index_t width_blocks = RoundUpDiv4(width); + + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width_blocks), + static_cast(height * batch)}; + + auto runtime = OpenCLRuntime::Global(); + if (kernel->get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depthwise_conv2d"); if (stride == 1 && dilations[0] == 1 && dilations[1] == 1) { @@ -104,13 +110,15 @@ void DepthwiseConv2d(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(dilations[0])); kernel->setArg(idx++, static_cast(dilations[1])); } + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width_blocks), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::string tuning_key = Concat("depthwise_conv2d_ocl_kernel_", activation, batch, height, width, channels, multiplier); TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future); diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc index 548d907d..a2e4e8f1 100644 --- a/mace/kernels/opencl/eltwise_opencl.cc +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -24,8 +24,12 @@ void EltwiseFunctor::operator()(const Tensor *input0, const index_t width_pixels = channel_blocks * width; const index_t batch_height_pixels = batch * height; + const uint32_t gws[2] = {static_cast(width_pixels), + static_cast(batch_height_pixels)}; + + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; auto dt = DataTypeToEnum::value; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("eltwise"); @@ -45,12 +49,14 @@ void EltwiseFunctor::operator()(const Tensor *input0, kernel_.setArg(idx++, coeff_[1]); } kernel_.setArg(idx++, *(output->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); input_shape_ = input0->shape(); } - const uint32_t gws[2] = {static_cast(width_pixels), - static_cast(batch_height_pixels)}; - const std::vector lws = {64, 16, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {kwg_size / 16, 16, 1}; std::stringstream ss; ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc index f4b7b222..3e17f98f 100644 --- a/mace/kernels/opencl/fully_connected_opencl.cc +++ b/mace/kernels/opencl/fully_connected_opencl.cc @@ -75,6 +75,7 @@ void FCWXKernel(cl::Kernel *kernel, if (!IsVecEqual(*prev_input_shape, input->shape())) { const index_t batch = output->dim(0); const index_t output_blocks = RoundUpDiv4(output->dim(3)); + (*gws)[2] = static_cast(batch * output_blocks); uint32_t idx = 0; kernel->setArg(idx++, *(input->opencl_image())); @@ -90,14 +91,21 @@ void FCWXKernel(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(RoundUpDiv4(input->dim(3)))); kernel->setArg(idx++, static_cast(output_blocks)); kernel->setArg(idx++, relux_max_limit); - - (*gws)[2] = static_cast(batch * output_blocks); + kernel->setArg(idx++, (*gws)[0]); + kernel->setArg(idx++, (*gws)[1]); + kernel->setArg(idx++, (*gws)[2]); *prev_input_shape = input->shape(); } + + std::vector roundup_gws(lws->size()); + for (size_t i = 0; i < lws->size(); ++i) { + roundup_gws[i] = RoundUp((*gws)[i], (*lws)[i]); + } + cl::Event event; cl_int error = runtime->command_queue().enqueueNDRangeKernel( - *kernel, cl::NullRange, cl::NDRange((*gws)[0], (*gws)[1], (*gws)[2]), + *kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]), cl::NDRange((*lws)[0], (*lws)[1], (*lws)[2]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; @@ -161,6 +169,13 @@ void FCWTXKernel(cl::Kernel *kernel, } if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; + const index_t batch = output->dim(0); + const index_t output_blocks = RoundUpDiv4(output->dim(3)); + + *gws = { + static_cast(batch), static_cast(output_blocks), + }; + kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(weight->opencl_image())); if (bias != nullptr) { @@ -172,13 +187,9 @@ void FCWTXKernel(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(input->dim(3))); // FIXME handle flexable data type: half not supported kernel->setArg(idx++, relux_max_limit); + kernel->setArg(idx++, (*gws)[0]); + kernel->setArg(idx++, (*gws)[1]); - const index_t batch = output->dim(0); - const index_t output_blocks = RoundUpDiv4(output->dim(3)); - - *gws = { - static_cast(batch), static_cast(output_blocks), - }; *prev_input_shape = input->shape(); } diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index ee52625a..2141c65e 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -226,12 +226,7 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel, {4, kwg_size / 28, 7, 1}, {4, kwg_size / 32, 8, 1}, {4, kwg_size / 56, 14, 1}, - {3, 15, 9, 1}, - {7, 15, 9, 1}, - {9, 7, 15, 1}, - {15, 7, 9, 1}, {1, kwg_size, 1, 1}, - {4, 15, 8, 1}, }; }; cl::Event event; @@ -240,6 +235,11 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel, MACE_CHECK(params.size() == 4) << "Tuning parameters of 3D kernel must be 4D"; cl_int error = CL_SUCCESS; + std::vector roundup_gws(3); + for (size_t i = 0; i < 3; ++i) { + roundup_gws[i] = RoundUp(gws[i], params[i]); + } + if (timer == nullptr) { uint32_t num_blocks = params[3]; const uint32_t block_size = gws[2] / num_blocks; @@ -247,16 +247,17 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel, for (uint32_t i = 0; i < num_blocks; ++i) { uint32_t gws2 = (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size; + uint32_t roundup_gws2 = RoundUp(gws2, params[2]); error = runtime->command_queue().enqueueNDRangeKernel( kernel, cl::NDRange(0, 0, i * block_size), - cl::NDRange(gws[0], gws[1], gws2), + cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws2), cl::NDRange(params[0], params[1], params[2]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; } } else { timer->ClearTiming(); error = runtime->command_queue().enqueueNDRangeKernel( - kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), + kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]), cl::NDRange(params[0], params[1], params[2]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; timer->AccumulateTiming(); @@ -273,9 +274,10 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel, for (uint32_t i = 0; i < num_blocks; ++i) { uint32_t gws2 = (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size; + uint32_t roundup_gws2 = RoundUp(gws2, params[2]); error = runtime->command_queue().enqueueNDRangeKernel( kernel, cl::NDRange(0, 0, i * block_size), - cl::NDRange(gws[0], gws[1], gws2), + cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws2), cl::NDRange(params[0], params[1], params[2]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; timer->AccumulateTiming(); @@ -318,7 +320,6 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel, {kwg_size / 64, 64, 1}, {kwg_size / 128, 128, 1}, {kwg_size / 256, 256, 1}, - {kwg_size / 512, 512, 1}, {kwg_size, 1, 1}, {1, kwg_size, 1}}; }; @@ -328,6 +329,11 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel, MACE_CHECK(params.size() == 3) << "Tuning parameters of 2D kernel must be 3d"; cl_int error = CL_SUCCESS; + std::vector roundup_gws(2); + for (size_t i = 0; i < 2; ++i) { + roundup_gws[i] = RoundUp(gws[i], params[i]); + } + if (timer == nullptr) { uint32_t num_blocks = params[2]; const uint32_t block_size = gws[1] / num_blocks; @@ -335,15 +341,16 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel, for (uint32_t i = 0; i < num_blocks; ++i) { uint32_t gws1 = (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size; + uint32_t roundup_gws1 = RoundUp(gws1, params[1]); error = runtime->command_queue().enqueueNDRangeKernel( - kernel, cl::NDRange(0, i * block_size), cl::NDRange(gws[0], gws1), + kernel, cl::NDRange(0, i * block_size), cl::NDRange(roundup_gws[0], roundup_gws1), cl::NDRange(params[0], params[1]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; } } else { timer->ClearTiming(); error = runtime->command_queue().enqueueNDRangeKernel( - kernel, cl::NullRange, cl::NDRange(gws[0], gws[1]), + kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1]), cl::NDRange(params[0], params[1]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; timer->AccumulateTiming(); @@ -360,8 +367,9 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel, for (uint32_t i = 0; i < num_blocks; ++i) { uint32_t gws1 = (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size; + uint32_t roundup_gws1 = RoundUp(gws1, params[1]); error = runtime->command_queue().enqueueNDRangeKernel( - kernel, cl::NDRange(0, i * block_size), cl::NDRange(gws[0], gws1), + kernel, cl::NDRange(0, i * block_size), cl::NDRange(roundup_gws[0], roundup_gws1), cl::NDRange(params[0], params[1]), nullptr, &event); MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; timer->AccumulateTiming(); diff --git a/mace/kernels/opencl/matmul.cc b/mace/kernels/opencl/matmul.cc index c5bd2b0b..3609b1a6 100644 --- a/mace/kernels/opencl/matmul.cc +++ b/mace/kernels/opencl/matmul.cc @@ -26,9 +26,14 @@ void MatMulFunctor::operator()(const Tensor *A, const index_t height_blocks = RoundUpDiv4(height); const index_t width_blocks = RoundUpDiv4(width); + const uint32_t gws[2] = { + static_cast(width_blocks), + static_cast(height_blocks * batch), + }; + + auto runtime = OpenCLRuntime::Global(); if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; auto dt = DataTypeToEnum::value; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("matmul"); @@ -46,12 +51,12 @@ void MatMulFunctor::operator()(const Tensor *A, kernel_.setArg(idx++, static_cast(A->dim(2))); kernel_.setArg(idx++, static_cast(height_blocks)); kernel_.setArg(idx++, static_cast(RoundUpDiv4(A->dim(2)))); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); - const uint32_t gws[2] = { - static_cast(width_blocks), - static_cast(height_blocks * batch), - }; - const std::vector lws = {16, 64, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {kwg_size / 64, 64, 1}; std::stringstream ss; ss << "matmul_opencl_kernel_" << C->dim(0) << "_" << C->dim(1) << "_" << C->dim(2) << "_" << C->dim(3); diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc index 5b52a093..4e97174e 100644 --- a/mace/kernels/opencl/pooling_opencl.cc +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -18,9 +18,10 @@ void PoolingFunctor::operator()(const Tensor *input, MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1) << "Pooling opencl kernel not support dilation yet"; + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { const DataType dt = DataTypeToEnum::value; - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("pooling"); built_options.emplace("-Dpooling=" + kernel_name); @@ -37,6 +38,8 @@ void PoolingFunctor::operator()(const Tensor *input, } kernel_ = runtime->BuildKernel("pooling", kernel_name, built_options); } + + uint32_t gws[3]; if (!IsVecEqual(input_shape_, input->shape())) { std::vector output_shape(4); std::vector filter_shape = {kernels_[0], kernels_[1], @@ -59,6 +62,17 @@ void PoolingFunctor::operator()(const Tensor *input, &output_image_shape); output->ResizeImage(output_shape, output_image_shape); + index_t batch = output->dim(0); + index_t out_height = output->dim(1); + index_t out_width = output->dim(2); + index_t channels = output->dim(3); + + index_t channel_blocks = (channels + 3) / 4; + + gws[0] = static_cast(channel_blocks); + gws[1] = static_cast(out_width); + gws[2] = static_cast(batch * out_height); + uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, static_cast(input->dim(1))); @@ -69,23 +83,27 @@ void PoolingFunctor::operator()(const Tensor *input, kernel_.setArg(idx++, strides_[0]); kernel_.setArg(idx++, kernels_[0]); kernel_.setArg(idx++, *(output->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); - } + } else { + index_t batch = output->dim(0); + index_t out_height = output->dim(1); + index_t out_width = output->dim(2); + index_t channels = output->dim(3); - index_t batch = output->dim(0); - index_t out_height = output->dim(1); - index_t out_width = output->dim(2); - index_t channels = output->dim(3); - - index_t channel_blocks = (channels + 3) / 4; + index_t channel_blocks = (channels + 3) / 4; + gws[0] = static_cast(channel_blocks); + gws[1] = static_cast(out_width); + gws[2] = static_cast(batch * out_height); + } - const uint32_t gws[3] = { - static_cast(channel_blocks), static_cast(out_width), - static_cast(batch * out_height), - }; - std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "pooling_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index 37370916..d6a18519 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -24,8 +24,13 @@ void ResizeBilinearFunctor::operator()( const index_t out_height = out_height_; const index_t out_width = out_width_; + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(out_width), + static_cast(out_height * batch)}; + + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("resize_bilinear_nocache"); built_options.emplace("-Dresize_bilinear_nocache=" + kernel_name); @@ -57,14 +62,16 @@ void ResizeBilinearFunctor::operator()( kernel_.setArg(idx++, static_cast(in_height)); kernel_.setArg(idx++, static_cast(in_width)); kernel_.setArg(idx++, static_cast(out_height)); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(out_width), - static_cast(out_height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "resize_bilinear_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); diff --git a/mace/kernels/opencl/slice.cc b/mace/kernels/opencl/slice.cc index 6bc9ae3b..f4e39089 100644 --- a/mace/kernels/opencl/slice.cc +++ b/mace/kernels/opencl/slice.cc @@ -29,8 +29,9 @@ void SliceFunctor::operator()( output_list[i]->ResizeImage(output_shape, image_shape); } + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice"); built_options.emplace("-Dslice=" + kernel_name); @@ -46,7 +47,10 @@ void SliceFunctor::operator()( static_cast(input->dim(2)), static_cast(input->dim(0) * input->dim(1)), }; - const std::vector lws = {8, 16, 8, 1}; + + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "slice_opencl_kernel_" << input->dim(0) << "_" @@ -59,6 +63,9 @@ void SliceFunctor::operator()( kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, static_cast(channel_blk * i)); kernel_.setArg(idx++, *(output_list[i]->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); } diff --git a/mace/kernels/opencl/softmax_opencl.cc b/mace/kernels/opencl/softmax_opencl.cc index 077db9dd..3ec6447a 100644 --- a/mace/kernels/opencl/softmax_opencl.cc +++ b/mace/kernels/opencl/softmax_opencl.cc @@ -23,9 +23,12 @@ void SoftmaxFunctor::operator()(const Tensor *logits, const index_t channel_blocks = RoundUpDiv4(channels); const int remain_channels = channel_blocks * 4 - channels; - if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), + static_cast(height * batch)}; + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { std::set built_options; std::string kernel_name = MACE_OBFUSCATE_SYMBOL("softmax"); built_options.emplace("-Dsoftmax=" + kernel_name); @@ -40,12 +43,15 @@ void SoftmaxFunctor::operator()(const Tensor *logits, kernel_.setArg(idx++, static_cast(channels)); kernel_.setArg(idx++, remain_channels); kernel_.setArg(idx++, *(output->opencl_image())); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); input_shape_ = logits->shape(); } - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8, 1}; + + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << "softmax_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); diff --git a/mace/kernels/opencl/space_to_batch_opencl.cc b/mace/kernels/opencl/space_to_batch_opencl.cc index fe911fbd..b2de2748 100644 --- a/mace/kernels/opencl/space_to_batch_opencl.cc +++ b/mace/kernels/opencl/space_to_batch_opencl.cc @@ -31,9 +31,15 @@ void SpaceToBatchFunctor::operator()( batch_tensor->ResizeImage(output_shape, output_image_shape); kernel_name = "space_to_batch"; } + const uint32_t chan_blk = RoundUpDiv4(batch_tensor->dim(3)); + const uint32_t gws[3] = { + chan_blk, static_cast(batch_tensor->dim(2)), + static_cast(batch_tensor->dim(0) * batch_tensor->dim(1))}; + + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name); - auto runtime = OpenCLRuntime::Global(); std::set built_options; std::stringstream kernel_name_ss; kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name; @@ -61,15 +67,16 @@ void SpaceToBatchFunctor::operator()( kernel_.setArg(idx++, static_cast(space_tensor->dim(2))); kernel_.setArg(idx++, static_cast(batch_tensor->dim(1))); kernel_.setArg(idx++, static_cast(batch_tensor->dim(2))); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); space_shape_ = space_tensor->shape(); } - const uint32_t chan_blk = RoundUpDiv4(batch_tensor->dim(3)); - const uint32_t gws[3] = { - chan_blk, static_cast(batch_tensor->dim(2)), - static_cast(batch_tensor->dim(0) * batch_tensor->dim(1))}; - const std::vector lws = {8, 16, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {8, kwg_size / 64, 8, 1}; std::stringstream ss; ss << kernel_name << "_" << batch_tensor->dim(0) << "_" << batch_tensor->dim(1) << "_" << batch_tensor->dim(2) << "_" diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index 3b866408..b3f4889b 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -15,6 +15,8 @@ template void WinogradTransformFunctor::operator()( const Tensor *input_tensor, Tensor *output_tensor, StatsFuture *future) { + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2"); @@ -24,7 +26,6 @@ void WinogradTransformFunctor::operator()( DtToUpstreamCLDt(DataTypeToEnum::value)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum::value)); - auto runtime = OpenCLRuntime::Global(); kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name, built_options); } @@ -44,6 +45,9 @@ void WinogradTransformFunctor::operator()( const index_t round_h = (output_shape[1] + 1) / 2; const index_t round_w = (output_shape[2] + 1) / 2; const index_t out_width = input_tensor->dim(0) * round_h * round_w; + const uint32_t gws[2] = { + static_cast(out_width), + static_cast(RoundUpDiv4(input_tensor->dim(3)))}; if (!IsVecEqual(input_shape_, input_tensor->shape())) { output_shape = {16, input_tensor->dim(3), out_width, 1}; @@ -61,14 +65,15 @@ void WinogradTransformFunctor::operator()( kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, static_cast(paddings[0] / 2)); kernel_.setArg(idx++, static_cast(paddings[1] / 2)); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); input_shape_ = input_tensor->shape(); } - const uint32_t gws[2] = { - static_cast(out_width), - static_cast(RoundUpDiv4(input_tensor->dim(3)))}; - const std::vector lws = {128, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {kwg_size / 8, 8, 1}; std::stringstream ss; ss << "winograd_transform_kernel_" << input_tensor->dim(0) << "_" << input_tensor->dim(1) << "_" << input_tensor->dim(2) << "_" @@ -82,6 +87,9 @@ void WinogradInverseTransformFunctor::operator()( const Tensor *bias, Tensor *output_tensor, StatsFuture *future) { + + auto runtime = OpenCLRuntime::Global(); + if (kernel_.get() == nullptr) { std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL("winograd_inverse_transform_2x2"); @@ -115,10 +123,13 @@ void WinogradInverseTransformFunctor::operator()( LOG(FATAL) << "Unknown activation type: " << activation_; } - auto runtime = OpenCLRuntime::Global(); kernel_ = runtime->BuildKernel("winograd_transform", obfuscated_kernel_name, built_options); } + + const uint32_t gws[2] = { + static_cast(input_tensor->dim(2)), + static_cast(RoundUpDiv4(input_tensor->dim(1)))}; if (!IsVecEqual(input_shape_, input_tensor->shape())) { std::vector output_shape = {batch_, height_, width_, input_tensor->dim(1)}; @@ -143,14 +154,15 @@ void WinogradInverseTransformFunctor::operator()( kernel_.setArg(idx++, static_cast(round_h * round_w)); kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, relux_max_limit_); + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); input_shape_ = input_tensor->shape(); } - const uint32_t gws[2] = { - static_cast(input_tensor->dim(2)), - static_cast(RoundUpDiv4(input_tensor->dim(1)))}; - const std::vector lws = {128, 8, 1}; + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + const std::vector lws = {kwg_size / 8, 8, 1}; std::stringstream ss; ss << "winograd_inverse_transform_kernel_" << input_tensor->dim(0) << "_" -- GitLab