From aeb4c35e37a6c0d7ef9d1416fd1ee410fe003a21 Mon Sep 17 00:00:00 2001
From: liuqi
Date: Mon, 30 Oct 2017 09:21:05 +0800
Subject: [PATCH] Use vector operation to optimize batch_norm opencl kernel.

---
 mace/kernels/opencl/batch_norm_opencl.cc | 11 +++++++----
 mace/kernels/opencl/cl/batch_norm.cl     | 23 +++++++++++++++++------
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc
index 23bedc6e..82efa459 100644
--- a/mace/kernels/opencl/batch_norm_opencl.cc
+++ b/mace/kernels/opencl/batch_norm_opencl.cc
@@ -20,9 +20,12 @@ void BatchNormFunctor::operator()(
     const Tensor *epsilon,
     Tensor *output) {
 
+  index_t pixel_size = input->dim(2) * input->dim(3);
+  index_t blocks = (pixel_size + 3) / 4;
+
   const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)),
                            static_cast<uint32_t>(input->dim(1)),
-                           static_cast<uint32_t>(input->dim(2) * input->dim(3))};
+                           static_cast<uint32_t>(blocks)};
 
   auto runtime = OpenCLRuntime::Get();
 
@@ -39,10 +42,10 @@ void BatchNormFunctor::operator()(
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(mean->buffer())));
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(var->buffer())));
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(epsilon->buffer())));
-  bm_kernel.setArg(idx++, gws[2]);
+  bm_kernel.setArg(idx++, static_cast<uint32_t>(pixel_size));
   bm_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
-  bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
-  bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr);
+  bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr);
+  bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr);
 
   auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
     return {{1, 1, 64},
diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl
index e86c62f3..3b7dcf08 100644
--- a/mace/kernels/opencl/cl/batch_norm.cl
+++ b/mace/kernels/opencl/cl/batch_norm.cl
@@ -6,8 +6,8 @@ void kernel batch_norm(global const float *input,
                        global const float *epsilon,
                        private const uint pixels,
                        global float *output,
-                       __local float *new_scale,
-                       __local float *new_offset) {
+                       __local float4 *new_scale,
+                       __local float4 *new_offset) {
   const int batch = get_global_id(0);
   const int channel = get_global_id(1);
   const int channels = get_global_size(1);
@@ -16,15 +16,26 @@ void kernel batch_norm(global const float *input,
   const int local_pixel_idx = get_local_id(2);
 
   if(local_pixel_idx == 0) {
-    new_scale[local_channel] = scale[channel] * rsqrt(var[channel] + *epsilon);
-    new_offset[local_channel] = offset[channel] - mean[channel] * new_scale[local_channel];
+    new_scale[local_channel] = (float4)(scale[channel] * rsqrt(var[channel] + *epsilon));
+    new_offset[local_channel] = (float4)(offset[channel] - mean[channel] * new_scale[local_channel].x);
   }
   barrier(CLK_LOCAL_MEM_FENCE);
 
-  const int sample_offset = (batch * channels + channel) * pixels + pixel_offset;
+  const int sample_offset = (batch * channels + channel) * pixels + pixel_offset*4;
 
   const float *input_ptr = input + sample_offset;
   float *output_ptr = output + sample_offset;
-  *output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel];
+  const int end = (batch * channels + channel + 1) * pixels;
+  if ((sample_offset+4) > end) {
+    for (int i = sample_offset; i < end; ++i) {
+      *output_ptr = new_scale[local_channel].x * *input_ptr + new_offset[local_channel].x;
+      ++input_ptr;
+      ++output_ptr;
+    }
+  } else {
+    float4 values = vload4(0, input_ptr);
+    values = values * new_scale[local_channel] + new_offset[local_channel];
+    vstore4(values, 0, output_ptr);
+  }
 }
-- 
GitLab
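
Note (not part of the patch): the change shrinks the third global-work-size dimension from one work-item per pixel to one per block of 4 pixels (blocks = (pixel_size + 3) / 4), folds the batch-norm parameters once per channel into new_scale/new_offset float4 values in local memory (hence the lws[1] * sizeof(float) * 4 local-memory arguments), and lets each work-item process its 4 pixels with vload4 / float4 math / vstore4, falling back to a scalar loop for a partial tail block. Below is a minimal stand-alone CPU sketch of that blocking scheme, not MACE code; the sizes and values are made up for illustration, and only the names pixel_size, blocks, new_scale, and new_offset are taken from the patch.

// blocking_sketch.cc -- illustrates the 4-pixel blocking used by the patched kernel.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int pixel_size = 7;                 // H * W, deliberately not a multiple of 4
  const int blocks = (pixel_size + 3) / 4;  // mirrors `blocks = (pixel_size + 3) / 4`
  const float scale = 2.f, offset = 1.f, mean = 0.5f, var = 4.f, eps = 1e-5f;

  // Fold the batch-norm parameters once, as the kernel does per channel.
  const float new_scale = scale * (1.f / std::sqrt(var + eps));
  const float new_offset = offset - mean * new_scale;

  std::vector<float> input(pixel_size), output(pixel_size);
  for (int i = 0; i < pixel_size; ++i) input[i] = static_cast<float>(i);

  for (int b = 0; b < blocks; ++b) {  // one iteration ~ one work-item
    const int start = b * 4;
    if (start + 4 > pixel_size) {
      // Tail block: scalar loop over the remaining pixels, like the kernel's `if` branch.
      for (int i = start; i < pixel_size; ++i)
        output[i] = new_scale * input[i] + new_offset;
    } else {
      // Full block: 4 pixels at once (vload4 + float4 math + vstore4 on the GPU).
      for (int i = start; i < start + 4; ++i)
        output[i] = new_scale * input[i] + new_offset;
    }
  }

  for (int i = 0; i < pixel_size; ++i) std::printf("%f\n", output[i]);
  return 0;
}

Processing 4 pixels per work-item reduces the number of launched work-items and lets the GPU issue wider loads and stores, while the tail branch keeps the kernel correct when H * W is not a multiple of 4.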