From 23e0668058ad55e0ffb9217b57549abbd712feac Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 21 Sep 2022 16:50:19 +0800 Subject: [PATCH] [Sparse]Conv sort out (#46216) * sort out index --- paddle/phi/kernels/sparse/gpu/conv.cu.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h index 2a524eb4650..77eea316290 100644 --- a/paddle/phi/kernels/sparse/gpu/conv.cu.h +++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include "paddle/phi/kernels/sparse/conv_kernel.h" @@ -186,8 +187,7 @@ __global__ void UniqueKernel(const IntT* in_indexs, if (i < rulebook_len) { // atomicOr only support int int index = static_cast<int>(in_indexs[i]); - int change_index = index == 0 ? -1 : index; - int flag = atomicOr(out_index_table + index, change_index); + int flag = atomicOr(out_index_table + index, 1); if (flag == 0) { int j = atomicAdd(&count, 1); cache[j] = index; @@ -772,6 +772,7 @@ int ProductRuleBook(const Context& dev_ctx, phi::backends::gpu::GpuMemsetAsync( out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); + phi::backends::gpu::GpuMemsetAsync( unique_key_ptr, 0, sizeof(int), dev_ctx.stream()); @@ -785,6 +786,7 @@ int ProductRuleBook(const Context& dev_ctx, out_index_table_ptr, out_index_ptr, unique_key_ptr); + int out_nnz = 0; phi::backends::gpu::GpuMemcpyAsync(&out_nnz, unique_key_ptr, @@ -792,6 +794,13 @@ int ProductRuleBook(const Context& dev_ctx, gpuMemcpyDeviceToHost, dev_ctx.stream()); dev_ctx.Wait(); +#ifdef PADDLE_WITH_HIP + thrust::sort(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::sort(thrust::cuda::par.on(dev_ctx.stream()), +#endif + out_index_ptr, + out_index_ptr + out_nnz); const int64_t sparse_dim = 4; phi::DenseTensor out_indices = -- GitLab