From eefc9d01bcd59cf541c9a438266d66fecdc8dc25 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Thu, 29 Mar 2018 12:50:21 +0800 Subject: [PATCH] update cl source code format, fix pooling op core dump --- mace/kernels/opencl/activation_opencl.cc | 8 +- mace/kernels/opencl/addn.cc | 6 +- mace/kernels/opencl/batch_norm_opencl.cc | 8 +- mace/kernels/opencl/bias_add_opencl.cc | 8 +- mace/kernels/opencl/buffer_to_image.cc | 6 +- mace/kernels/opencl/channel_shuffle.cc | 8 +- mace/kernels/opencl/cl/activation.cl | 16 +- mace/kernels/opencl/cl/addn.cl | 14 +- mace/kernels/opencl/cl/batch_norm.cl | 16 +- mace/kernels/opencl/cl/bias_add.cl | 11 +- mace/kernels/opencl/cl/buffer_to_image.cl | 168 ++++++++---------- mace/kernels/opencl/cl/channel_shuffle.cl | 18 +- mace/kernels/opencl/cl/concat.cl | 26 ++- mace/kernels/opencl/cl/conv_2d.cl | 16 +- mace/kernels/opencl/cl/conv_2d_1x1.cl | 16 +- mace/kernels/opencl/cl/conv_2d_3x3.cl | 16 +- mace/kernels/opencl/cl/depth_to_space.cl | 27 ++- mace/kernels/opencl/cl/depthwise_conv2d.cl | 32 ++-- mace/kernels/opencl/cl/eltwise.cl | 14 +- mace/kernels/opencl/cl/matmul.cl | 14 +- mace/kernels/opencl/cl/pooling.cl | 15 +- mace/kernels/opencl/cl/resize_bilinear.cl | 15 +- mace/kernels/opencl/cl/slice.cl | 12 +- mace/kernels/opencl/cl/softmax.cl | 14 +- mace/kernels/opencl/cl/space_to_batch.cl | 32 ++-- mace/kernels/opencl/cl/winograd_transform.cl | 28 ++- mace/kernels/opencl/concat.cc | 16 +- mace/kernels/opencl/conv_2d_opencl_1x1.cc | 8 +- mace/kernels/opencl/conv_2d_opencl_3x3.cc | 8 +- mace/kernels/opencl/conv_2d_opencl_general.cc | 8 +- mace/kernels/opencl/depth_to_space_opencl.cc | 19 +- mace/kernels/opencl/depthwise_conv_opencl.cc | 8 +- mace/kernels/opencl/eltwise_opencl.cc | 6 +- mace/kernels/opencl/matmul.cc | 6 +- mace/kernels/opencl/pooling_opencl.cc | 20 ++- mace/kernels/opencl/resize_bilinear_opencl.cc | 8 +- mace/kernels/opencl/slice.cc | 8 +- mace/kernels/opencl/softmax_opencl.cc | 8 +- mace/kernels/opencl/space_to_batch_opencl.cc | 8 +- mace/kernels/opencl/winograd_transform.cc | 12 +- tools/mace_tools.py | 3 +- tools/packaging_lib.sh | 7 +- 42 files changed, 355 insertions(+), 362 deletions(-) diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc index d3e6c7f9..f41513c5 100644 --- a/mace/kernels/opencl/activation_opencl.cc +++ b/mace/kernels/opencl/activation_opencl.cc @@ -71,6 +71,11 @@ void ActivationFunctor::operator()(const Tensor *input, if (!IsVecEqual(input_shape_, input->shape())) { int idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); if (activation_ == PRELU) { MACE_CHECK_NOTNULL(alpha); @@ -78,9 +83,6 @@ void ActivationFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, static_cast(relux_max_limit_)); kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc index d7c149a9..c2c19fa7 100644 --- a/mace/kernels/opencl/addn.cc +++ b/mace/kernels/opencl/addn.cc @@ -70,12 +70,14 @@ void AddNFunctor::operator()( output_tensor->ResizeImage(output_shape, output_image_shape); uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + } for (auto input : input_tensors) { kernel_.setArg(idx++, *(input->opencl_image())); } kernel_.setArg(idx++, *(output_tensor->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); input_shape_ = input_tensors[0]->shape(); diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index c3a1765c..8065acba 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -75,6 +75,11 @@ void BatchNormFunctor::operator()(const Tensor *input, } if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(scale->opencl_image())); kernel_.setArg(idx++, *(offset->opencl_image())); @@ -85,9 +90,6 @@ void BatchNormFunctor::operator()(const Tensor *input, } kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, relux_max_limit_); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/bias_add_opencl.cc b/mace/kernels/opencl/bias_add_opencl.cc index e67ebe71..a518f074 100644 --- a/mace/kernels/opencl/bias_add_opencl.cc +++ b/mace/kernels/opencl/bias_add_opencl.cc @@ -45,12 +45,14 @@ void BiasAddFunctor::operator()(const Tensor *input, } if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(bias->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); kwg_size_ = diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index 565b3d56..5d7ae4c1 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -87,6 +87,10 @@ void BufferToImageFunctor::operator()( obfuscated_kernel_name, built_options); uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported) { + b2f_kernel.setArg(idx++, gws[0]); + b2f_kernel.setArg(idx++, gws[1]); + } b2f_kernel.setArg(idx++, *(buffer->opencl_buffer())); if (!i2b_) { MACE_CHECK(buffer->buffer_offset() % GetEnumTypeSize(buffer->dtype()) == 0, @@ -112,8 +116,6 @@ void BufferToImageFunctor::operator()( b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); } b2f_kernel.setArg(idx++, *(image->opencl_image())); - b2f_kernel.setArg(idx++, gws[0]); - b2f_kernel.setArg(idx++, gws[1]); const uint32_t kwg_size = static_cast(runtime->GetKernelMaxWorkGroupSize(b2f_kernel)); diff --git a/mace/kernels/opencl/channel_shuffle.cc b/mace/kernels/opencl/channel_shuffle.cc index 316ae62a..29097345 100644 --- a/mace/kernels/opencl/channel_shuffle.cc +++ b/mace/kernels/opencl/channel_shuffle.cc @@ -54,13 +54,15 @@ void ChannelShuffleFunctor::operator()( if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, groups_); kernel_.setArg(idx++, static_cast(channels_per_group)); kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/cl/activation.cl b/mace/kernels/opencl/cl/activation.cl index a02b0e35..7976dd38 100644 --- a/mace/kernels/opencl/cl/activation.cl +++ b/mace/kernels/opencl/cl/activation.cl @@ -1,19 +1,17 @@ #include -__kernel void activation(__read_only image2d_t input, +__kernel void activation( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, #ifdef USE_PRELU __read_only image2d_t alpha, #endif __private const float relux_max_limit, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __write_only image2d_t output) { -#endif - const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/addn.cl b/mace/kernels/opencl/cl/addn.cl index 23e47e50..09dd5c38 100644 --- a/mace/kernels/opencl/cl/addn.cl +++ b/mace/kernels/opencl/cl/addn.cl @@ -1,6 +1,11 @@ #include -__kernel void addn(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ +__kernel void addn( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t input1, #if INPUT_NUM > 2 __read_only image2d_t input2, @@ -8,14 +13,7 @@ __kernel void addn(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ #if INPUT_NUM > 3 __read_only image2d_t input3, #endif -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - const int w = get_global_id(0); const int hb = get_global_id(1); diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index d36c1e8b..0e592fdc 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -1,6 +1,12 @@ #include // Supported data types: half/float -__kernel void batch_norm(__read_only image2d_t input, +__kernel void batch_norm( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset, #ifndef FOLDED_CONSTANT @@ -9,15 +15,7 @@ __kernel void batch_norm(__read_only image2d_t input, __private const float epsilon, #endif __write_only image2d_t output, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const float relux_max_limit, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const float relux_max_limit) { -#endif - const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/bias_add.cl b/mace/kernels/opencl/cl/bias_add.cl index 594528ce..ee7b6078 100644 --- a/mace/kernels/opencl/cl/bias_add.cl +++ b/mace/kernels/opencl/cl/bias_add.cl @@ -1,15 +1,14 @@ #include // Supported data types: half/float -__kernel void bias_add(__read_only image2d_t input, - __read_only image2d_t bias, +__kernel void bias_add( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif + __read_only image2d_t input, + __read_only image2d_t bias, + __write_only image2d_t output) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/buffer_to_image.cl b/mace/kernels/opencl/cl/buffer_to_image.cl index a5d9f289..7e764503 100644 --- a/mace/kernels/opencl/cl/buffer_to_image.cl +++ b/mace/kernels/opencl/cl/buffer_to_image.cl @@ -1,19 +1,17 @@ #include -__kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, oc, ic */ +__kernel void filter_buffer_to_image( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global const DATA_TYPE *input, /* h, w, oc, ic */ __private const int input_offset, __private const int filter_h, __private const int filter_w, __private const int out_channel, __private const int in_channel, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -58,19 +56,17 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, o WRITE_IMAGET(output, coord, values); } -__kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic */ +__kernel void filter_image_to_buffer( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global DATA_TYPE *output, /* h, w, oc, ic */ __private const int filter_h, __private const int filter_w, __private const int out_channel, __private const int in_channel, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __read_only image2d_t input, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __read_only image2d_t input) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -112,19 +108,17 @@ __kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic } } -__kernel void dw_filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, ic, m */ +__kernel void dw_filter_buffer_to_image( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global const DATA_TYPE *input, /* h, w, ic, m */ __private const int input_offset, __private const int filter_w, __private const int in_channel, __private const int multiplier, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { /* ic%4 * kh * kw * m, ic/4 */ -#else - __write_only image2d_t output) { -#endif - + __write_only image2d_t output) { /* ic%4 * kh * kw * m, ic/4 */ const int w = get_global_id(0); const int h = get_global_id(1); @@ -175,19 +169,17 @@ __kernel void dw_filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w WRITE_IMAGET(output, coord, values); } -__kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ +__kernel void in_out_buffer_to_image( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global const DATA_TYPE *input, /* nhwc */ __private const int input_offset, __private const int height, __private const int width, __private const int channels, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -222,18 +214,16 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ WRITE_IMAGET(output, coord, values); } -__kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ +__kernel void in_out_image_to_buffer( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global DATA_TYPE *output, /* nhwc */ __private const int height, __private const int width, __private const int channels, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __read_only image2d_t input, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __read_only image2d_t input) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -267,17 +257,15 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ } } -__kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ - __private const int input_offset, - __private const int count, +__kernel void arg_buffer_to_image( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, - __private const int global_size_dim1) { -#else - __write_only image2d_t output) { + __private const int global_size_dim1, #endif - + __global const DATA_TYPE *input, /* nhwc */ + __private const int input_offset, + __private const int count, + __write_only image2d_t output) { int w = get_global_id(0); int h = get_global_id(1); @@ -308,16 +296,14 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ WRITE_IMAGET(output, coord, values); } -__kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ - __private const int count, +__kernel void arg_image_to_buffer( #ifndef USE_QUALCOMM_OPENCL_2_0 - __read_only image2d_t input, __private const int global_size_dim0, - __private const int global_size_dim1) { -#else - __read_only image2d_t input) { + __private const int global_size_dim1, #endif - + __global DATA_TYPE *output, /* nhwc */ + __private const int count, + __read_only image2d_t input) { int w = get_global_id(0); int h = get_global_id(1); @@ -347,19 +333,17 @@ __kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ } -__kernel void in_out_height_buffer_to_image(__global const DATA_TYPE *input, //nhwc +__kernel void in_out_height_buffer_to_image( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global const DATA_TYPE *input, //nhwc __private const int input_offset, __private const int height, __private const int width, __private const int channels, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -395,18 +379,16 @@ __kernel void in_out_height_buffer_to_image(__global const DATA_TYPE *input, //n WRITE_IMAGET(output, coord, values); } -__kernel void in_out_height_image_to_buffer(__global DATA_TYPE *output, //nhwc +__kernel void in_out_height_image_to_buffer( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global DATA_TYPE *output, //nhwc __private const int height, __private const int width, __private const int channels, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __read_only image2d_t input, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __read_only image2d_t input) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -439,19 +421,17 @@ __kernel void in_out_height_image_to_buffer(__global DATA_TYPE *output, //nhwc } -__kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ +__kernel void in_out_width_buffer_to_image( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global const DATA_TYPE *input, /* nhwc */ __private const int input_offset, __private const int height, __private const int width, __private const int channels, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -487,19 +467,17 @@ __kernel void in_out_width_buffer_to_image(__global const DATA_TYPE *input, /* n } // only support 3x3 now -__kernel void winograd_filter_buffer_to_image(__global const DATA_TYPE *input, //Oc, Ic, H, W +__kernel void winograd_filter_buffer_to_image( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global const DATA_TYPE *input, //Oc, Ic, H, W __private const int input_offset, __private const int in_channels, __private const int height, __private const int width, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - int w = get_global_id(0); int h = get_global_id(1); @@ -584,18 +562,16 @@ __kernel void winograd_filter_buffer_to_image(__global const DATA_TYPE *input, / } // only support 3x3 now -__kernel void winograd_filter_image_to_buffer(__global DATA_TYPE *output, //Oc, Ic, H, W +__kernel void winograd_filter_image_to_buffer( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __global DATA_TYPE *output, //Oc, Ic, H, W __private const int height, __private const int width, __private const int channel, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __read_only image2d_t input, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __read_only image2d_t input) { -#endif - const int w = get_global_id(0); const int h = get_global_id(1); diff --git a/mace/kernels/opencl/cl/channel_shuffle.cl b/mace/kernels/opencl/cl/channel_shuffle.cl index 87159784..5bf0e067 100644 --- a/mace/kernels/opencl/cl/channel_shuffle.cl +++ b/mace/kernels/opencl/cl/channel_shuffle.cl @@ -1,18 +1,16 @@ #include // assume channes_per_group mod 4 = 0 && groups mod 4 == 0 -__kernel void channel_shuffle(__read_only image2d_t input, - __private const int groups, - __private const int channels_per_group, +__kernel void channel_shuffle( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, #endif - + __read_only image2d_t input, + __private const int groups, + __private const int channels_per_group, + __write_only image2d_t output) { const int group_chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); const int hb_idx = get_global_id(2); diff --git a/mace/kernels/opencl/cl/concat.cl b/mace/kernels/opencl/cl/concat.cl index c8bfebaa..2658025d 100644 --- a/mace/kernels/opencl/cl/concat.cl +++ b/mace/kernels/opencl/cl/concat.cl @@ -22,18 +22,16 @@ DATA_TYPE4 stitch_vector(DATA_TYPE4 left, } // Supported data type: half/float -__kernel void concat_channel(__read_only image2d_t input0, - __read_only image2d_t input1, - __private const int input0_chan, +__kernel void concat_channel( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif - + __read_only image2d_t input0, + __read_only image2d_t input1, + __private const int input0_chan, + __write_only image2d_t output) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); const int hb_idx = get_global_id(2); @@ -90,17 +88,15 @@ __kernel void concat_channel(__read_only image2d_t input0, } // Required: All input channels are divisible by 4 -__kernel void concat_channel_multi(__read_only image2d_t input, - __private const int chan_blk_offset, +__kernel void concat_channel_multi( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif - + __read_only image2d_t input, + __private const int chan_blk_offset, + __write_only image2d_t output) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); const int hb_idx = get_global_id(2); diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index 8fa23f02..f40f31da 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -1,6 +1,12 @@ #include -__kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ +__kernel void conv_2d( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * cin, kh * kw * cout/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ @@ -18,15 +24,7 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __private const int padding_top, __private const int padding_left, __private const int dilation_h, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int dilation_w, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const int dilation_w) { -#endif - const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); const int out_hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index 70d88867..96d9a2c0 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -1,6 +1,12 @@ #include -__kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ +__kernel void conv_2d_1x1( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * cin, cout/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ @@ -12,15 +18,7 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] __private const int in_ch_blks, __private const int height, __private const int width, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int stride, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const int stride) { -#endif - const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); const int out_hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index 8ce485b7..b159fd6a 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -1,6 +1,12 @@ #include -__kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ +__kernel void conv_2d_3x3( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * cin , kh * kw * cout/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ @@ -16,15 +22,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] __private const int padding_top, __private const int padding_left, __private const int dilation_h, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int dilation_w, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const int dilation_w) { -#endif - const int out_ch_blk = get_global_id(0); const int out_w_blk = get_global_id(1); const int out_hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/depth_to_space.cl b/mace/kernels/opencl/cl/depth_to_space.cl index 349b665d..2a5a8893 100644 --- a/mace/kernels/opencl/cl/depth_to_space.cl +++ b/mace/kernels/opencl/cl/depth_to_space.cl @@ -1,17 +1,15 @@ #include -__kernel void depth_to_space(__read_only image2d_t input, - __private const int block_size, - __private const int output_depth, +__kernel void depth_to_space( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif - + __read_only image2d_t input, + __private const int block_size, + __private const int output_depth, + __write_only image2d_t output) { const int out_d = get_global_id(0); const int out_w = get_global_id(1); const int out_h = get_global_id(2); @@ -44,17 +42,16 @@ __kernel void depth_to_space(__read_only image2d_t input, WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); } -__kernel void space_to_depth(__read_only image2d_t input, - __private const int block_size, - __private const int input_depth, +__kernel void space_to_depth( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif + __read_only image2d_t input, + __private const int block_size, + __private const int input_depth, + __write_only image2d_t output) { const int d = get_global_id(0); const int w = get_global_id(1); diff --git a/mace/kernels/opencl/cl/depthwise_conv2d.cl b/mace/kernels/opencl/cl/depthwise_conv2d.cl index 7d39d3c1..1974d8db 100644 --- a/mace/kernels/opencl/cl/depthwise_conv2d.cl +++ b/mace/kernels/opencl/cl/depthwise_conv2d.cl @@ -1,7 +1,13 @@ #include // Only multiplier = 1 is supported -__kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ +__kernel void depthwise_conv2d( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ @@ -18,15 +24,7 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h __private const short padding_top, __private const short padding_left, __private const short dilation_h, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const short dilation_w, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const short dilation_w) { -#endif - const short out_ch_blk = get_global_id(0); const short out_w_blk = get_global_id(1); const short out_hb = get_global_id(2); @@ -144,7 +142,13 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out3); } -__kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ +__kernel void depthwise_conv2d_s1( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ @@ -159,15 +163,7 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4 __private const short filter_height, __private const short filter_width, __private const short padding_top, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const short padding_left, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const short padding_left) { -#endif - const short out_ch_blk = get_global_id(0); const short out_w_blk = get_global_id(1) << 2; const short out_hb = get_global_id(2); diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index d7c90e03..9a69af1a 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -1,19 +1,17 @@ #include -__kernel void eltwise(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ +__kernel void eltwise( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ __read_only image2d_t input1, #ifdef COEFF_SUM __private const float coeff0, __private const float coeff1, #endif -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __write_only image2d_t output) { -#endif - const int w = get_global_id(0); const int hb = get_global_id(1); diff --git a/mace/kernels/opencl/cl/matmul.cl b/mace/kernels/opencl/cl/matmul.cl index 7107838c..c3efc9f2 100644 --- a/mace/kernels/opencl/cl/matmul.cl +++ b/mace/kernels/opencl/cl/matmul.cl @@ -1,21 +1,19 @@ #include // C = A * B -__kernel void matmul(__read_only image2d_t A, +__kernel void matmul( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __read_only image2d_t A, __read_only image2d_t B, __write_only image2d_t C, __private const int M, __private const int N, __private const int K, __private const int height_blocks, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int k_blocks, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __private const int k_blocks) { -#endif - const int gx = get_global_id(0) << 2; const int hb = get_global_id(1); diff --git a/mace/kernels/opencl/cl/pooling.cl b/mace/kernels/opencl/cl/pooling.cl index 8cdc4e46..0a28b745 100644 --- a/mace/kernels/opencl/cl/pooling.cl +++ b/mace/kernels/opencl/cl/pooling.cl @@ -19,7 +19,13 @@ inline int calculate_avg_block_size(const int pool_size, } // Supported data type: half/float -__kernel void pooling(__read_only image2d_t input, +__kernel void pooling( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, __private const int in_height, __private const int in_width, __private const int out_height, @@ -27,14 +33,7 @@ __kernel void pooling(__read_only image2d_t input, __private const int pad_left, __private const int stride, __private const int pooling_size, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __write_only image2d_t output) { -#endif const int out_chan_idx = get_global_id(0); const int out_width_idx = get_global_id(1); diff --git a/mace/kernels/opencl/cl/resize_bilinear.cl b/mace/kernels/opencl/cl/resize_bilinear.cl index 5369c762..8c5b7a33 100644 --- a/mace/kernels/opencl/cl/resize_bilinear.cl +++ b/mace/kernels/opencl/cl/resize_bilinear.cl @@ -1,19 +1,18 @@ #include -__kernel void resize_bilinear_nocache(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ +__kernel void resize_bilinear_nocache( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ __write_only image2d_t output, __private const float height_scale, __private const float width_scale, __private const int in_height, __private const int in_width, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int out_height, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const int out_height) { -#endif const int ch_blk = get_global_id(0); const int w = get_global_id(1); diff --git a/mace/kernels/opencl/cl/slice.cl b/mace/kernels/opencl/cl/slice.cl index bb5f40cd..4517ec99 100644 --- a/mace/kernels/opencl/cl/slice.cl +++ b/mace/kernels/opencl/cl/slice.cl @@ -1,16 +1,14 @@ #include -__kernel void slice(__read_only image2d_t input, - __private const int chan_blk_offset, +__kernel void slice( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif - + __read_only image2d_t input, + __private const int chan_blk_offset, + __write_only image2d_t output) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); const int hb_idx = get_global_id(2); diff --git a/mace/kernels/opencl/cl/softmax.cl b/mace/kernels/opencl/cl/softmax.cl index 3fadd18e..11ff80bf 100644 --- a/mace/kernels/opencl/cl/softmax.cl +++ b/mace/kernels/opencl/cl/softmax.cl @@ -1,17 +1,15 @@ #include -__kernel void softmax(__read_only image2d_t input, - __private const int channels, - __private const int remain_channels, +__kernel void softmax( #ifndef USE_QUALCOMM_OPENCL_2_0 - __write_only image2d_t output, __private const int global_size_dim0, __private const int global_size_dim1, - __private const int global_size_dim2) { -#else - __write_only image2d_t output) { + __private const int global_size_dim2, #endif - + __read_only image2d_t input, + __private const int channels, + __private const int remain_channels, + __write_only image2d_t output) { const int chan_blk_idx = get_global_id(0); const int width_idx = get_global_id(1); const int hb_idx = get_global_id(2); diff --git a/mace/kernels/opencl/cl/space_to_batch.cl b/mace/kernels/opencl/cl/space_to_batch.cl index 822d0906..0a546012 100644 --- a/mace/kernels/opencl/cl/space_to_batch.cl +++ b/mace/kernels/opencl/cl/space_to_batch.cl @@ -1,6 +1,12 @@ #include -__kernel void space_to_batch(__read_only image2d_t space_data, +__kernel void space_to_batch( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t space_data, __write_only image2d_t batch_data, __private const int block_height, __private const int block_width, @@ -9,15 +15,7 @@ __kernel void space_to_batch(__read_only image2d_t space_data, __private const int space_height, __private const int space_width, __private const int batch_height, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int batch_width, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const int batch_width) { -#endif - const int chan_idx = get_global_id(0); const int batch_w_idx = get_global_id(1); const int batch_hb_idx = get_global_id(2); @@ -54,7 +52,13 @@ __kernel void space_to_batch(__read_only image2d_t space_data, WRITE_IMAGET(batch_data, batch_coord, value); } -__kernel void batch_to_space(__read_only image2d_t batch_data, +__kernel void batch_to_space( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, +#endif + __read_only image2d_t batch_data, __write_only image2d_t space_data, __private const int block_height, __private const int block_width, @@ -63,15 +67,7 @@ __kernel void batch_to_space(__read_only image2d_t batch_data, __private const int space_height, __private const int space_width, __private const int batch_height, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int batch_width, - __private const int global_size_dim0, - __private const int global_size_dim1, - __private const int global_size_dim2) { -#else __private const int batch_width) { -#endif - const int chan_idx = get_global_id(0); const int batch_w_idx = get_global_id(1); const int batch_hb_idx = get_global_id(2); diff --git a/mace/kernels/opencl/cl/winograd_transform.cl b/mace/kernels/opencl/cl/winograd_transform.cl index 098c8e3b..f3f99cfa 100644 --- a/mace/kernels/opencl/cl/winograd_transform.cl +++ b/mace/kernels/opencl/cl/winograd_transform.cl @@ -1,6 +1,11 @@ #include -__kernel void winograd_transform_2x2(__read_only image2d_t input, +__kernel void winograd_transform_2x2( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __read_only image2d_t input, __write_only image2d_t output, __private const int in_height, __private const int in_width, @@ -8,14 +13,7 @@ __kernel void winograd_transform_2x2(__read_only image2d_t input, __private const int round_hw, __private const int round_w, __private const int padding_top, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const int padding_left, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __private const int padding_left) { -#endif - int out_width_idx = get_global_id(0); int chan_blk_idx = get_global_id(1); @@ -121,7 +119,12 @@ __kernel void winograd_transform_2x2(__read_only image2d_t input, } } -__kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, +__kernel void winograd_inverse_transform_2x2( +#ifndef USE_QUALCOMM_OPENCL_2_0 + __private const int global_size_dim0, + __private const int global_size_dim1, +#endif + __read_only image2d_t input, #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif @@ -130,14 +133,7 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input, __private const int out_width, __private const int round_hw, __private const int round_w, -#ifndef USE_QUALCOMM_OPENCL_2_0 - __private const float relux_max_limit, - __private const int global_size_dim0, - __private const int global_size_dim1) { -#else __private const float relux_max_limit) { -#endif - const int width_idx = get_global_id(0); const int height_idx = get_global_id(1); diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc index 111b7a9c..56449d14 100644 --- a/mace/kernels/opencl/concat.cc +++ b/mace/kernels/opencl/concat.cc @@ -56,6 +56,11 @@ static void Concat2(cl::Kernel *kernel, } if (!IsVecEqual(*prev_input_shape, input0->shape())) { uint32_t idx = 0; + if (!(*is_non_uniform_work_groups_supported)) { + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + } kernel->setArg(idx++, *(static_cast(input0->opencl_image()))); kernel->setArg(idx++, @@ -63,9 +68,6 @@ static void Concat2(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(input0->dim(3))); kernel->setArg(idx++, *(static_cast(output->opencl_image()))); - kernel->setArg(idx++, gws[0]); - kernel->setArg(idx++, gws[1]); - kernel->setArg(idx++, gws[2]); *prev_input_shape = input0->shape(); @@ -119,12 +121,14 @@ static void ConcatN(cl::Kernel *kernel, }; uint32_t idx = 0; + if (!(*is_non_uniform_work_groups_supported)) { + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + } kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, static_cast(chan_blk_offset)); kernel->setArg(idx++, *(output->opencl_image())); - kernel->setArg(idx++, gws[0]); - kernel->setArg(idx++, gws[1]); - kernel->setArg(idx++, gws[2]); chan_blk_offset += input_channel_blk; *kwg_size = diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index be2fd08b..34055189 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -84,6 +84,11 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; + if (!(*is_non_uniform_work_groups_supported)) { + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + } kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); if (bias != nullptr) { @@ -98,9 +103,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(height)); kernel->setArg(idx++, static_cast(width)); kernel->setArg(idx++, stride); - kernel->setArg(idx++, gws[0]); - kernel->setArg(idx++, gws[1]); - kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc index cec0927f..88793dac 100644 --- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc +++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc @@ -79,6 +79,11 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; + if (!(*is_non_uniform_work_groups_supported)) { + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + } kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); if (bias != nullptr) { @@ -96,9 +101,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[1]); - kernel->setArg(idx++, gws[0]); - kernel->setArg(idx++, gws[1]); - kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc b/mace/kernels/opencl/conv_2d_opencl_general.cc index a9151b48..19132209 100644 --- a/mace/kernels/opencl/conv_2d_opencl_general.cc +++ b/mace/kernels/opencl/conv_2d_opencl_general.cc @@ -79,6 +79,11 @@ extern void Conv2dOpencl(cl::Kernel *kernel, if (!IsVecEqual(*prev_input_shape, input->shape())) { uint32_t idx = 0; + if (!(*is_non_uniform_work_groups_supported)) { + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + } kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); if (bias != nullptr) { @@ -98,9 +103,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[1]); - kernel->setArg(idx++, gws[0]); - kernel->setArg(idx++, gws[1]); - kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc index 4daeac61..83cff273 100644 --- a/mace/kernels/opencl/depth_to_space_opencl.cc +++ b/mace/kernels/opencl/depth_to_space_opencl.cc @@ -68,12 +68,6 @@ void DepthToSpaceOpFunctor::operator()( uint32_t gws[3]; std::stringstream ss; if (!IsVecEqual(input_shape_, input->shape())) { - uint32_t idx = 0; - kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, block_size_); - kernel_.setArg(idx++, depth_blocks); - kernel_.setArg(idx++, *(output->opencl_image())); - if (d2s_) { gws[0] = static_cast(depth_blocks); gws[1] = static_cast(output_width); @@ -88,9 +82,16 @@ void DepthToSpaceOpFunctor::operator()( << input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3); } - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); + uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } + kernel_.setArg(idx++, *(input->opencl_image())); + kernel_.setArg(idx++, block_size_); + kernel_.setArg(idx++, depth_blocks); + kernel_.setArg(idx++, *(output->opencl_image())); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index 873a16a4..11bb38b3 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -97,6 +97,11 @@ void DepthwiseConv2d(cl::Kernel *kernel, input_channels); uint32_t idx = 0; + if (!(*is_non_uniform_work_groups_supported)) { + kernel->setArg(idx++, gws[0]); + kernel->setArg(idx++, gws[1]); + kernel->setArg(idx++, gws[2]); + } kernel->setArg(idx++, *(input->opencl_image())); kernel->setArg(idx++, *(filter->opencl_image())); if (bias != nullptr) { @@ -117,9 +122,6 @@ void DepthwiseConv2d(cl::Kernel *kernel, kernel->setArg(idx++, static_cast(dilations[0])); kernel->setArg(idx++, static_cast(dilations[1])); } - kernel->setArg(idx++, gws[0]); - kernel->setArg(idx++, gws[1]); - kernel->setArg(idx++, gws[2]); *prev_input_shape = input->shape(); diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc index 13413130..38a231da 100644 --- a/mace/kernels/opencl/eltwise_opencl.cc +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -47,6 +47,10 @@ void EltwiseFunctor::operator()(const Tensor *input0, } if (!IsVecEqual(input_shape_, input0->shape())) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + } kernel_.setArg(idx++, *(input0->opencl_image())); kernel_.setArg(idx++, *(input1->opencl_image())); if (!coeff_.empty()) { @@ -54,8 +58,6 @@ void EltwiseFunctor::operator()(const Tensor *input0, kernel_.setArg(idx++, coeff_[1]); } kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); input_shape_ = input0->shape(); diff --git a/mace/kernels/opencl/matmul.cc b/mace/kernels/opencl/matmul.cc index 19769f3d..79dcc40d 100644 --- a/mace/kernels/opencl/matmul.cc +++ b/mace/kernels/opencl/matmul.cc @@ -48,6 +48,10 @@ void MatMulFunctor::operator()(const Tensor *A, kernel_ = runtime->BuildKernel("matmul", kernel_name, built_options); } uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + } kernel_.setArg(idx++, *(A->opencl_image())); kernel_.setArg(idx++, *(B->opencl_image())); kernel_.setArg(idx++, *(C->opencl_image())); @@ -56,8 +60,6 @@ void MatMulFunctor::operator()(const Tensor *A, kernel_.setArg(idx++, static_cast(A->dim(2))); kernel_.setArg(idx++, static_cast(height_blocks)); kernel_.setArg(idx++, static_cast(RoundUpDiv4(A->dim(2)))); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); kwg_size_ = static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc index fa9e1577..9b2f96c8 100644 --- a/mace/kernels/opencl/pooling_opencl.cc +++ b/mace/kernels/opencl/pooling_opencl.cc @@ -81,6 +81,11 @@ void PoolingFunctor::operator()(const Tensor *input, }; uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, static_cast(input->dim(1))); kernel_.setArg(idx++, static_cast(input->dim(2))); @@ -90,14 +95,23 @@ void PoolingFunctor::operator()(const Tensor *input, kernel_.setArg(idx++, strides_[0]); kernel_.setArg(idx++, kernels_[0]); kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); kwg_size_ = static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + } else { + index_t batch = output->dim(0); + index_t out_height = output->dim(1); + index_t out_width = output->dim(2); + index_t channels = output->dim(3); + + index_t channel_blocks = (channels + 3) / 4; + + gws = { + static_cast(channel_blocks), static_cast(out_width), + static_cast(batch * out_height), + }; } std::vector lws = {8, kwg_size_ / 64, 8, 1}; diff --git a/mace/kernels/opencl/resize_bilinear_opencl.cc b/mace/kernels/opencl/resize_bilinear_opencl.cc index 5bcb53e3..ce2fe7bf 100644 --- a/mace/kernels/opencl/resize_bilinear_opencl.cc +++ b/mace/kernels/opencl/resize_bilinear_opencl.cc @@ -60,6 +60,11 @@ void ResizeBilinearFunctor::operator()( CalculateResizeScale(in_width, out_width, align_corners_); uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, height_scale); @@ -67,9 +72,6 @@ void ResizeBilinearFunctor::operator()( kernel_.setArg(idx++, static_cast(in_height)); kernel_.setArg(idx++, static_cast(in_width)); kernel_.setArg(idx++, static_cast(out_height)); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/slice.cc b/mace/kernels/opencl/slice.cc index 94f541b2..d610e1e1 100644 --- a/mace/kernels/opencl/slice.cc +++ b/mace/kernels/opencl/slice.cc @@ -65,12 +65,14 @@ void SliceFunctor::operator()( << outputs_count; for (int i = 0; i < outputs_count; ++i) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, static_cast(channel_blk * i)); kernel_.setArg(idx++, *(output_list[i]->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); } diff --git a/mace/kernels/opencl/softmax_opencl.cc b/mace/kernels/opencl/softmax_opencl.cc index 6b06cc8f..61ea0228 100644 --- a/mace/kernels/opencl/softmax_opencl.cc +++ b/mace/kernels/opencl/softmax_opencl.cc @@ -45,13 +45,15 @@ void SoftmaxFunctor::operator()(const Tensor *logits, } if (!IsVecEqual(input_shape_, logits->shape())) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } kernel_.setArg(idx++, *(logits->opencl_image())); kernel_.setArg(idx++, static_cast(channels)); kernel_.setArg(idx++, remain_channels); kernel_.setArg(idx++, *(output->opencl_image())); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); input_shape_ = logits->shape(); diff --git a/mace/kernels/opencl/space_to_batch_opencl.cc b/mace/kernels/opencl/space_to_batch_opencl.cc index 6e00f6ea..38da0548 100644 --- a/mace/kernels/opencl/space_to_batch_opencl.cc +++ b/mace/kernels/opencl/space_to_batch_opencl.cc @@ -57,6 +57,11 @@ void SpaceToBatchFunctor::operator()( } if (!IsVecEqual(space_shape_, space_tensor->shape())) { uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); + } if (b2s_) { kernel_.setArg(idx++, *(batch_tensor->opencl_image())); kernel_.setArg(idx++, *(space_tensor->opencl_image())); @@ -72,9 +77,6 @@ void SpaceToBatchFunctor::operator()( kernel_.setArg(idx++, static_cast(space_tensor->dim(2))); kernel_.setArg(idx++, static_cast(batch_tensor->dim(1))); kernel_.setArg(idx++, static_cast(batch_tensor->dim(2))); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - kernel_.setArg(idx++, gws[2]); space_shape_ = space_tensor->shape(); diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index 905b1346..092a60cd 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -61,6 +61,10 @@ void WinogradTransformFunctor::operator()( output_tensor->ResizeImage(output_shape, image_shape); uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + } kernel_.setArg(idx++, *(input_tensor->opencl_image())); kernel_.setArg(idx++, *(output_tensor->opencl_image())); kernel_.setArg(idx++, static_cast(input_tensor->dim(1))); @@ -70,8 +74,6 @@ void WinogradTransformFunctor::operator()( kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, static_cast(paddings[0] / 2)); kernel_.setArg(idx++, static_cast(paddings[1] / 2)); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); input_shape_ = input_tensor->shape(); @@ -151,6 +153,10 @@ void WinogradInverseTransformFunctor::operator()( const uint32_t round_h = (height_ + 1) / 2; const uint32_t round_w = (width_ + 1) / 2; uint32_t idx = 0; + if (!is_non_uniform_work_groups_supported_) { + kernel_.setArg(idx++, gws[0]); + kernel_.setArg(idx++, gws[1]); + } kernel_.setArg( idx++, *(static_cast(input_tensor->opencl_image()))); @@ -165,8 +171,6 @@ void WinogradInverseTransformFunctor::operator()( kernel_.setArg(idx++, static_cast(round_h * round_w)); kernel_.setArg(idx++, static_cast(round_w)); kernel_.setArg(idx++, relux_max_limit_); - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); input_shape_ = input_tensor->shape(); diff --git a/tools/mace_tools.py b/tools/mace_tools.py index 4f2b209a..2e0ea3fa 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -376,7 +376,8 @@ def main(unused_args): build_run_throughput_test(target_soc, FLAGS.run_seconds, merged_lib_file, FLAGS.output_dir) - packaging_lib_file(FLAGS.output_dir) + if FLAGS.mode == "build" or FLAGS.mode == "all": + packaging_lib_file(FLAGS.output_dir) if __name__ == "__main__": diff --git a/tools/packaging_lib.sh b/tools/packaging_lib.sh index 60751487..c6158cd5 100644 --- a/tools/packaging_lib.sh +++ b/tools/packaging_lib.sh @@ -14,8 +14,13 @@ source ${CURRENT_DIR}/env.sh LIBMACE_BUILD_DIR=$1 +TAR_PACKAGE_NAME=libmace_${PROJECT_NAME}.tar.gz + pushd $LIBMACE_BUILD_DIR/$PROJECT_NAME -ls | grep -v build | xargs tar cvzf libmace_${PROJECT_NAME}.tar.gz +if [ -f $TAR_PACKAGE_NAME ]; then + rm -f $TAR_PACKAGE_NAME +fi +ls | grep -v build | xargs tar cvzf $TAR_PACKAGE_NAME popd echo "Packaging done!" -- GitLab