diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h index d68145e9585740aca644bcba2395e5e762f846ec..1dc3686c41990b6b408b09c11906dec913564808 100644 --- a/paddle/phi/kernels/sparse/gpu/conv.cu.h +++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h @@ -662,9 +662,9 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx.stream()); dev_ctx.Wait(); - size_t cache_size = kernel_size * 2 + kernel_size * - config.thread_per_block.x * 2 * - sizeof(int); + size_t cache_size = + kernel_size * 2 * sizeof(int) + + kernel_size * config.thread_per_block.x * 2 * sizeof(int); const int MAX_CACHE_SIZE = 48 * 1024; while (cache_size >= MAX_CACHE_SIZE) { config.thread_per_block.x /= 2; @@ -672,7 +672,7 @@ int ProductRuleBook(const Context& dev_ctx, PADDLE_ENFORCE_GE(config.thread_per_block.x, 32, phi::errors::Fatal("the shared memory is not enough")); - cache_size = kernel_size * 2 + + cache_size = kernel_size * 2 * sizeof(int) + kernel_size * config.thread_per_block.x * 2 * sizeof(int); } ProductSubmRuleBookKernel<<