From 41ba27229243d5f222d094360295ad8b9d9fc438 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 29 Nov 2022 09:58:44 +0800 Subject: [PATCH] group the index in not cutlass mode (#48439) --- paddle/phi/kernels/sparse/gpu/conv_kernel.cu | 38 ++++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu index e5e3cd0f5c1..87037581e52 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu @@ -123,25 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); } - if (subm) { - auto config = - phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); - unique_value.ResizeAndAllocate( - {static_cast(out->nnz() * kernel_size)}); - out_index.ResizeAndAllocate({static_cast(rulebook_len)}); - int* out_index_ptr = out_index.data(); - int* unique_value_ptr = unique_value.data(); - phi::backends::gpu::GpuMemsetAsync( - out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); - GroupIndexs<<>>(rulebook_len, - kernel_size, - rulebook_ptr + rulebook_len, - out_index_ptr, - unique_value_ptr); - } #ifdef PADDLE_WITH_CUTLASS bool cutlass = true; if (dev_ctx.GetComputeCapability() < 80) cutlass = false; @@ -226,6 +207,25 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, } } else { #endif + if (subm) { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + unique_value.ResizeAndAllocate( + {static_cast(out->nnz() * kernel_size)}); + out_index.ResizeAndAllocate({static_cast(rulebook_len)}); + int* out_index_ptr = out_index.data(); + int* unique_value_ptr = unique_value.data(); + phi::backends::gpu::GpuMemsetAsync( + out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); + GroupIndexs<<>>(rulebook_len, + kernel_size, + rulebook_ptr + rulebook_len, + out_index_ptr, + unique_value_ptr); + } // 2. gather phi::DenseTensor in_features = phi::Empty(dev_ctx, {rulebook_len, in_channels}); -- GitLab