diff --git a/mace/kernels/opencl/depthwise_conv.cc b/mace/kernels/opencl/depthwise_conv.cc index ec358d092ef1da623190b55ff9b8da04a03bb1c4..ca44be2faa79b0b537fc5db7f416179e02e489c1 100644 --- a/mace/kernels/opencl/depthwise_conv.cc +++ b/mace/kernels/opencl/depthwise_conv.cc @@ -38,7 +38,7 @@ std::vector LocalWS(const uint32_t *gws, const uint32_t kwg_size) { kwg_size / lws[1]); } } - lws[0] = std::max(lws[0], 1); + lws[0] = std::max(std::min(lws[0], kwg_size / lws[1]), 1); const uint32_t lws_size = lws[0] * lws[1]; lws[2] = std::min((cache_size / kernel_cache_size / lws_size) * 4, gws[2]);