diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index b39e28b8178e16d1b3959c441fb3d31e0421c3f0..f2f5be1bba3bebf9e3d43b5563778e5e8ebf5179 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -323,21 +323,21 @@ void OpenCLRuntime::GetCallStats(const cl::Event &event, CallStats *stats) { } } -uint32_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() { +uint64_t OpenCLRuntime::GetDeviceMaxWorkGroupSize() { uint64_t size = 0; device_->getInfo(CL_DEVICE_MAX_WORK_GROUP_SIZE, &size); - return static_cast(size); + return size; } -uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel &kernel) { +uint64_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel &kernel) { uint64_t size = 0; kernel.getWorkGroupInfo(*device_, CL_KERNEL_WORK_GROUP_SIZE, &size); - return static_cast(size); + return size; } // TODO(liuqi): not compatible with mali gpu. -uint32_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) { - uint32_t size = 0; +uint64_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) { + uint64_t size = 0; kernel.getWorkGroupInfo(*device_, CL_KERNEL_WAVE_SIZE_QCOM, &size); return size; } diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h index 69ea4233c3f7012850b2738ba639cb295d3cd580..f5e2c25bb5d0a62b0bc403791bb8f18df9b80938 100644 --- a/mace/core/runtime/opencl/opencl_runtime.h +++ b/mace/core/runtime/opencl/opencl_runtime.h @@ -46,9 +46,9 @@ class OpenCLRuntime { cl::CommandQueue &command_queue(); void GetCallStats(const cl::Event &event, CallStats *stats); - uint32_t GetDeviceMaxWorkGroupSize(); - uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel); - uint32_t GetKernelWaveSize(const cl::Kernel &kernel); + uint64_t GetDeviceMaxWorkGroupSize(); + uint64_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel); + uint64_t GetKernelWaveSize(const cl::Kernel &kernel); cl::Kernel BuildKernel(const std::string &program_name, const std::string &kernel_name, const std::set &build_options); diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc index 772a6d8d0c17774de35dca46e96fd9a15c94c38c..f4b7b2223349e45b9f7d02976a0d4184ea3ee0ad 100644 --- a/mace/kernels/opencl/fully_connected_opencl.cc +++ b/mace/kernels/opencl/fully_connected_opencl.cc @@ -62,11 +62,13 @@ void FCWXKernel(cl::Kernel *kernel, const index_t batch = output->dim(0); const index_t output_size = output->dim(3); const index_t output_blocks = RoundUpDiv4(output_size); - const uint32_t wave_size = runtime->GetKernelWaveSize(*kernel); + const uint32_t wave_size = + static_cast(runtime->GetKernelWaveSize(*kernel)); *gws = {4, (wave_size / 4), static_cast(batch * output_blocks)}; - const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(*kernel); + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel)); const uint32_t inter_local_blks = kwg_size / ((*gws)[0] * (*gws)[1]); *lws = {(*gws)[0], (*gws)[1], inter_local_blks}; } diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index e3cadbc6f5d1cd73b7f5b6a2de02c370a19ce0c1..ee52625a6337bb9be5390b4392fd5b93e5a88214 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -201,7 +201,8 @@ void TuningOrRun3DKernel(const cl::Kernel &kernel, StatsFuture *future) { auto runtime = OpenCLRuntime::Global(); auto params_generator = [&]() -> std::vector> { - const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel); + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel)); std::vector local_ws(3, 0); local_ws[0] = std::min(gws[0], kwg_size); local_ws[1] = std::min(gws[1], kwg_size / local_ws[0]); @@ -304,7 +305,8 @@ void TuningOrRun2DKernel(const cl::Kernel &kernel, StatsFuture *future) { auto runtime = OpenCLRuntime::Global(); auto params_generator = [&]() -> std::vector> { - const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(kernel); + const uint32_t kwg_size = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel)); uint32_t local_ws[2]; local_ws[0] = std::min(gws[0], kwg_size); local_ws[1] = std::min(gws[1], kwg_size / local_ws[0]);