diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index b97b7a9bb9e4fd5e1e271c0d04833410b0c99d0f..b85b9e48074f0d354aa5c85a12ad2a03e0a4f0a8 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -160,16 +160,16 @@ cl::Program &OpenCLRuntime::program() { return program_; } -int OpenCLRuntime::GetDeviceMaxWorkGroupSize() { +uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() { unsigned long long size = 0; device_.getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size); - return static_cast(size); + return static_cast(size); } -int OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) { +uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel& kernel) { unsigned long long size = 0; kernel.getWorkGroupInfo(device_, CL_KERNEL_WORK_GROUP_SIZE, &size); - return static_cast(size); + return static_cast(size); } } // namespace mace diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h index 26e8bd83e996587068e3161d72b527ff98c7903d..e7c7b180e43ba13dfdfedec9a63b5973bdfcb55a 100644 --- a/mace/core/runtime/opencl/opencl_runtime.h +++ b/mace/core/runtime/opencl/opencl_runtime.h @@ -21,8 +21,8 @@ class OpenCLRuntime { public: static OpenCLRuntime *Get(); - int GetDeviceMaxWorkGroupSize(); - int GetKernelMaxWorkGroupSize(const cl::Kernel& kernel); + uint32_t GetDeviceMaxWorkGroupSize(); + uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel& kernel); cl::Context &context(); cl::Device &device(); cl::CommandQueue &command_queue(); diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 3140b2b91ac82db990a4c96dbe01fd078af26259..4d0771f15819e71496100e441799801f35b191f2 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -22,13 +22,15 @@ void BatchNormFunctor::operator()( const uint32_t gws[3] = {static_cast(input->dim(0)), static_cast(input->dim(1)), static_cast(input->dim(2) * input->dim(3))}; - const uint32_t lws[3] = {1, 2, 128}; auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); auto bm_kernel = cl::Kernel(program, "batch_norm"); + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel); + const uint32_t lws[3] = {1, kwg_size/128, 128}; + uint32_t idx = 0; bm_kernel.setArg(idx++, *(static_cast(input->buffer()))); bm_kernel.setArg(idx++, *(static_cast(scale->buffer()))); @@ -41,9 +43,6 @@ void BatchNormFunctor::operator()( bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); bm_kernel.setArg(idx++, lws[1] * sizeof(float), nullptr); - MACE_CHECK(std::accumulate(lws, lws+3, 1, std::multiplies()) - < runtime->GetKernelMaxWorkGroupSize(bm_kernel)); - cl_int error = runtime->command_queue().enqueueNDRangeKernel( bm_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),