From 2fef9af5162bb4c3eacf1b3c5fd91355d64d25b8 Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 26 Oct 2017 11:24:01 +0800 Subject: [PATCH] Change the calculation of local work group size. --- mace/core/runtime/opencl/opencl_runtime.cc | 8 ++++---- mace/core/runtime/opencl/opencl_runtime.h | 4 ++-- mace/kernels/opencl/batch_norm_opencl.cc | 7 +++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index b97b7a9b..b85b9e48 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -160,16 +160,16 @@ cl::Program &OpenCLRuntime::program() { return program_; } -int OpenCLRuntime::GetDeviceMaxWorkGroupSize() { +uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() { unsigned long long size = 0; device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size); - return static_cast<int>(size); + return static_cast<uint32_t>(size); } -int OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) { +uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) { unsigned long long size = 0; kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size); - return static_cast<int>(size); + return static_cast<uint32_t>(size); } } // namespace mace diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h index 26e8bd83..e7c7b180 100644 --- a/mace/core/runtime/opencl/opencl_runtime.h +++ b/mace/core/runtime/opencl/opencl_runtime.h @@ -21,8 +21,8 @@ class OpenCLRuntime { public: static OpenCLRuntime *Get(); - int GetDeviceMaxWorkGroupSize(); - int GetKernelMaxWorkGroupSize(const cl::Kernel& kernel); + uint32_t GetDeviceMaxWorkGroupSize(); + uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel); cl::Context &context(); cl::Device &device(); cl::CommandQueue &command_queue(); diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 3140b2b9..4d0771f1 100644 
--- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -22,13 +22,15 @@ void BatchNormFunctor::operator()( const uint32_t gws[3] = {static_cast<uint32_t>(input->dim(0)), static_cast<uint32_t>(input->dim(1)), static_cast<uint32_t>(input->dim(2) * input->dim(3))}; - const uint32_t lws[3] = {1, 2, 128}; auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); auto bm_kernel = cl::Kernel(program, "batch_norm"); + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); + const uint32_t lws[3] = {1, kwg_size/128, 128}; + uint32_t idx = 0; bm_kernel.setArg(idx++, *(static_cast(input->buffer()))); bm_kernel.setArg(idx++, *(static_cast(scale->buffer()))); @@ -41,9 +43,6 @@ void BatchNormFunctor::operator()( bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); - MACE_CHECK(std::accumulate(lws, lws+3, 1, std::multiplies()) - < runtime->GetKernelMaxWorkGroupSize(bm_kernel)); - cl_int error = runtime->command_queue().enqueueNDRangeKernel( bm_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), -- GitLab