diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 23bedc6eb30a2753ba67c966c8102385dd293dd0..82efa4595ae6e4f091ce9618dc4e1b16199f382d 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -20,9 +20,12 @@ void BatchNormFunctor::operator()( const Tensor *epsilon, Tensor *output) { + index_t pixel_size = input->dim(2) * input->dim(3); + index_t blocks = (pixel_size + 3) / 4; + const uint32_t gws[3] = {static_cast(input->dim(0)), static_cast(input->dim(1)), - static_cast(input->dim(2) * input->dim(3))}; + static_cast(blocks)}; auto runtime = OpenCLRuntime::Get(); @@ -39,10 +42,10 @@ void BatchNormFunctor::operator()( bm_kernel.setArg(idx++, *(static_cast(mean->buffer()))); bm_kernel.setArg(idx++, *(static_cast(var->buffer()))); bm_kernel.setArg(idx++, *(static_cast(epsilon->buffer()))); - bm_kernel.setArg(idx++, gws[2]); + bm_kernel.setArg(idx++, static_cast(pixel_size)); bm_kernel.setArg(idx++, *(static_cast(output->buffer()))); - bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); - bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); + bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr); + bm_kernel.setArg(idx++, lws[1] * sizeof(float) * 4, nullptr); auto params_generator = [&kwg_size]()->std::vector> { return {{1, 1, 64}, diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index e86c62f336dd62b2aab0d5226f084756eff8350d..3b7dcf08c4f0a5fee6eddc3bde102c9fabff6ac1 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -6,8 +6,8 @@ void kernel batch_norm(global const float *input, global const float *epsilon, private const uint pixels, global float *output, - __local float *new_scale, - __local float *new_offset) { + __local float4 *new_scale, + __local float4 *new_offset) { const int batch = get_global_id(0); const int channel = get_global_id(1); const int channels = get_global_size(1); @@ -16,15 +16,26 @@ void kernel batch_norm(global const float *input, const int local_pixel_idx = get_local_id(2); if(local_pixel_idx == 0) { - new_scale[local_channel] = scale[channel] * rsqrt(var[channel] + *epsilon); - new_offset[local_channel] = offset[channel] - mean[channel] * new_scale[local_channel]; + new_scale[local_channel] = (float4)(scale[channel] * rsqrt(var[channel] + *epsilon)); + new_offset[local_channel] = (float4)(offset[channel] - mean[channel] * new_scale[local_channel].x); } barrier(CLK_LOCAL_MEM_FENCE); - const int sample_offset = (batch * channels + channel) * pixels + pixel_offset; + const int sample_offset = (batch * channels + channel) * pixels + pixel_offset*4; const float *input_ptr = input + sample_offset; float *output_ptr = output + sample_offset; - *output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel]; + const int end = (batch * channels + channel + 1) * pixels; + if ((sample_offset+4) > end) { + for (int i = sample_offset; i < end; ++i) { + *output_ptr = new_scale[local_channel].x * *input_ptr + new_offset[local_channel].x; + ++input_ptr; + ++output_ptr; + } + } else { + float4 values = vload4(0, input_ptr); + values = values * new_scale[local_channel] + new_offset[local_channel]; + vstore4(values, 0, output_ptr); + } }