diff --git a/paddle/phi/kernels/funcs/sparse/utils.cu.h b/paddle/phi/kernels/funcs/sparse/utils.cu.h
index 074fe1ca420497689cf7d6942bfe9c2709e5b191..f3b742dfc38cd516503be2993844b32de8b9cd2f 100644
--- a/paddle/phi/kernels/funcs/sparse/utils.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/utils.cu.h
@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
   }
 }
 
+inline __device__ bool SetBits(const int value, int* ptr) {
+  const int index = value >> 5;
+  const int mask = 1 << (value & 31);
+  const int old = atomicOr(ptr + index, mask);
+  return (mask & old) != 0;
+}
+
+inline __device__ bool TestBits(const int value, const int* ptr) {
+  const int index = value >> 5;
+  const int mask = 1 << (value & 31);
+  return (mask & ptr[index]) != 0;
+}
+
 }  // namespace sparse
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 161930a06fa854f49209f51f870e8c050ceb6e41..8618171b8f905aca24ab30caf020957e0a28f002 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -167,7 +167,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
 template <typename IntT>
 __global__ void UniqueKernel(const IntT* in_indexs,
                              const int rulebook_len,
-                             int* out_index_table,
+                             int* index_flags,
                              int* out_indexs,
                              int* nnz) {
   extern __shared__ int cache[];
@@ -182,8 +182,8 @@ __global__ void UniqueKernel(const IntT* in_indexs,
   if (i < rulebook_len) {
     // atomicOr only support int
     int index = static_cast<int>(in_indexs[i]);
-    int flag = atomicOr(out_index_table + index, 1);
-    if (flag == 0) {
+    const bool flag = phi::funcs::sparse::SetBits(index, index_flags);
+    if (!flag) {
       int j = atomicAdd(&count, 1);
       cache[j] = index;
     }
@@ -284,7 +284,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
             atomicAdd(&counter_buf[kernel_index], 1);
             kernel_i = kernel_index;
           }
-          // rulebook[kernel_index * non_zero_num + i] = kernel_i;
           rulebook[kernel_index * non_zero_num + i] = in_i;
           rulebook[kernel_index * non_zero_num + offset + i] = out_index;
           ++kernel_index;
@@ -299,17 +298,19 @@
 }
 
 template <typename IntT>
-__global__ void GetOutIndexTable(const IntT* indices,
-                                 const IntT non_zero_num,
-                                 const Dims4D dims,
-                                 int* out_index_table) {
+__global__ void GetOutIndexTable1(const IntT* indices,
+                                  const IntT non_zero_num,
+                                  const Dims4D dims,
+                                  int* index_flags,
+                                  int* out_index_table) {
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT batch = indices[i];
     IntT in_z = indices[i + non_zero_num];
     IntT in_y = indices[i + 2 * non_zero_num];
     IntT in_x = indices[i + 3 * non_zero_num];
     IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
-    out_index_table[index] = i == 0 ? -1 : i;
+    phi::funcs::sparse::SetBits(index, index_flags);
+    out_index_table[index] = i;
   }
 }
 
@@ -375,6 +376,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
                                           const Dims4D paddings,
                                           const Dims4D dilations,
                                           const Dims4D strides,
+                                          const int* index_flags,
                                           const int* out_index_table,
                                           T* rulebook,
                                           int* counter) {
@@ -417,9 +419,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
           T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
           out_index = phi::funcs::sparse::PointToIndex(
               batch, out_x, out_y, out_z, out_dims);
-          int real_out_index = out_index_table[out_index];
-          if (real_out_index != 0) {
-            real_out_index = real_out_index == -1 ? 0 : real_out_index;
+          const bool flag =
+              phi::funcs::sparse::TestBits(out_index, index_flags);
+          if (flag) {
+            int real_out_index = out_index_table[out_index];
             in_i = i;
             int buf_i = atomicAdd(&counter_buf[kernel_index], 1);
             kernel_i = kernel_index;
@@ -440,7 +443,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
   __syncthreads();
   for (int i = 0; i < kernel_size; i++) {
     if (threadIdx.x < counter_buf[i]) {
-      // rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = i;
       rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] =
           rulebook_buf[i * blockDim.x + threadIdx.x];
       rulebook[i * non_zero_num + offset + counter_buf2[i] + threadIdx.x] =
@@ -575,12 +577,18 @@ int ProductRuleBook(const Context& dev_ctx,
   DenseTensorMeta rulebook_meta(
       indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW);
 
-  int64_t table_size = 1;
+  int table_size = 1;
   for (int i = 0; i < out_dims.size() - 1; i++) {
     table_size *= out_dims[i];
   }
   DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
   int* out_index_table_ptr = out_index_table.data<int>();
+  // index_flags: flag the indices exist or not
+  int index_flags_size = (table_size + 31) / 32;
+  DenseTensor index_flags = phi::Empty<int>(dev_ctx, {index_flags_size});
+  int* index_flags_ptr = index_flags.data<int>();
+  phi::backends::gpu::GpuMemsetAsync(
+      index_flags_ptr, 0, sizeof(int) * index_flags.numel(), dev_ctx.stream());
 
   if (subm) {
     DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
@@ -590,16 +598,16 @@ int ProductRuleBook(const Context& dev_ctx,
 
     phi::Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, &out_indices);
 
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
-
     auto config =
         phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
-    GetOutIndexTable<IntT><<<config.block_per_grid,
-                             config.thread_per_block,
-                             0,
-                             dev_ctx.stream()>>>(
-        out_indices.data<IntT>(), non_zero_num, d_x_dims, out_index_table_ptr);
+    GetOutIndexTable1<IntT><<<config.block_per_grid,
+                              config.thread_per_block,
+                              0,
+                              dev_ctx.stream()>>>(out_indices.data<IntT>(),
+                                                  non_zero_num,
+                                                  d_x_dims,
+                                                  index_flags_ptr,
+                                                  out_index_table_ptr);
 
     size_t cache_size =
         kernel_size * 2 * sizeof(int) +
@@ -625,6 +633,7 @@ int ProductRuleBook(const Context& dev_ctx,
                                              d_paddings,
                                              d_dilations,
                                              d_strides,
+                                             index_flags_ptr,
                                              out_index_table_ptr,
                                              rulebook_ptr,
                                              counter_ptr);
@@ -695,9 +704,6 @@ int ProductRuleBook(const Context& dev_ctx,
     int* out_index_ptr = out_index->data<int>();
     int* unique_key_ptr = unique_key.data<int>();
 
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
-
     phi::backends::gpu::GpuMemsetAsync(
         unique_key_ptr, 0, sizeof(int), dev_ctx.stream());
 
@@ -708,7 +714,7 @@ int ProductRuleBook(const Context& dev_ctx,
                          cache_size,
                          dev_ctx.stream()>>>(rulebook_ptr + rulebook_len,
                                              rulebook_len,
-                                             out_index_table_ptr,
+                                             index_flags_ptr,
                                              out_index_ptr,
                                              unique_key_ptr);
 
diff --git a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
index c4d2a691a4b3bd4b8fe19207d8e8e6daa2bd7f74..45f827801bc10d2cb39f517dacfa6524aa297e6d 100644
--- a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
+#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
 
 namespace phi {
 namespace sparse {
@@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx,
 }
 
 template <typename IntT>
-__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) {
+__global__ void MaskTable(const IntT* x_indexs,
+                          const int n,
+                          int* index_flags,
+                          int* table) {
   CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
     int index = x_indexs[i];
-    table[index] = i == 0 ? -1 : i;
+    phi::funcs::sparse::SetBits(index, index_flags);
+    table[index] = i;
   }
 }
 
 template <typename T, typename IntT, int VecSize>
 __global__ void MaskCopy(const IntT* mask_indexs,
+                         const int* index_flags,
                          const int* table,
                          const int n,
                          const int stride,
@@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
   using LoadT = phi::AlignedVector<T, VecSize>;
   using StoreT = phi::AlignedVector<T, VecSize>;
   CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
-    int j = table[mask_indexs[i]];
-    if (j != 0) {
-      if (j == -1) j = 0;
+    const int mask_index = mask_indexs[i];
+    const bool flag = phi::funcs::sparse::TestBits(mask_index, index_flags);
+    if (flag) {
+      int j = table[mask_index];
       for (int k = 0; k < stride; k += VecSize) {
         LoadT vec_x;
         phi::Load<T, VecSize>(x_values + j * stride + k, &vec_x);
@@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
 
   int table_size = 1;
   auto x_dims = x.dims();
-  for (int i = 0; i < x_dims.size() - 1; i++) {
+  for (int i = 0; i < sparse_dim; i++) {
     table_size *= x_dims[i];
   }
   DenseTensor table = phi::Empty<int>(dev_ctx, {table_size});
-  phi::backends::gpu::GpuMemsetAsync(
-      table.data<int>(), 0, table_size * sizeof(int), dev_ctx.stream());
+  DenseTensor index_flags = phi::Empty<int>(dev_ctx, {(table_size + 31) / 32});
+  phi::backends::gpu::GpuMemsetAsync(index_flags.data<int>(),
+                                     0,
+                                     index_flags.numel() * sizeof(int),
+                                     dev_ctx.stream());
   const int64_t stride =
       x.dims().size() == sparse_dim ? 1 : x.values().dims()[1];
   *out = phi::EmptyLike<T>(dev_ctx, x.values());
@@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
   MaskTable<<<config.block_per_grid,
               config.thread_per_block,
               0,
-              dev_ctx.stream()>>>(
-      x_indexs_ptr, x_indexs.numel(), table.data<int>());
+              dev_ctx.stream()>>>(x_indexs_ptr,
+                                  x_indexs.numel(),
+                                  index_flags.data<int>(),
+                                  table.data<int>());
 
   config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
@@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                          config.thread_per_block,
                          0,
                          dev_ctx.stream()>>>(mask_indexs_ptr,
+                                             index_flags.data<int>(),
                                              table.data<int>(),
                                              mask_indexs.numel(),
                                              stride,
@@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                          config.thread_per_block,
                          0,
                          dev_ctx.stream()>>>(mask_indexs_ptr,
+                                             index_flags.data<int>(),
                                              table.data<int>(),
                                              mask_indexs.numel(),
                                              stride,
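Note on the SetBits/TestBits helpers this patch adds: they pack one presence bit per flattened index into an int array of (table_size + 31) / 32 words, so an index's existence is tracked separately from its row id in out_index_table. That removes the old sentinel encoding (storing -1 for row 0 and treating 0 as "absent") and shrinks the per-call memset from table_size ints to table_size / 32 ints. The standalone sketch below shows the same bit-table pattern in isolation; the MarkKernel/CheckKernel wrappers, the host driver, and the sizes are illustrative assumptions only, not part of the patch.

#include <cstdio>
#include <cuda_runtime.h>

// Same bit-table idea as SetBits/TestBits in utils.cu.h:
// index >> 5 selects the 32-bit word, 1 << (index & 31) selects the bit.
__device__ bool SetBits(const int value, int* ptr) {
  const int index = value >> 5;
  const int mask = 1 << (value & 31);
  const int old = atomicOr(ptr + index, mask);
  return (mask & old) != 0;  // true if the bit was already set
}

__device__ bool TestBits(const int value, const int* ptr) {
  const int index = value >> 5;
  const int mask = 1 << (value & 31);
  return (mask & ptr[index]) != 0;
}

// Illustrative kernels (not from the patch): mark a list of indices,
// then query membership for another list.
__global__ void MarkKernel(const int* indices, int n, int* flags) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) SetBits(indices[i], flags);
}

__global__ void CheckKernel(const int* queries, int n, const int* flags, bool* hit) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) hit[i] = TestBits(queries[i], flags);
}

int main() {
  const int table_size = 1024;                    // number of representable indices
  const int flag_words = (table_size + 31) / 32;  // one bit per index

  int h_marked[3] = {0, 37, 1023};
  int h_queries[4] = {0, 1, 37, 1023};

  int *d_marked, *d_queries, *d_flags;
  bool* d_hit;
  cudaMalloc(&d_marked, sizeof(h_marked));
  cudaMalloc(&d_queries, sizeof(h_queries));
  cudaMalloc(&d_flags, flag_words * sizeof(int));
  cudaMalloc(&d_hit, 4 * sizeof(bool));
  cudaMemcpy(d_marked, h_marked, sizeof(h_marked), cudaMemcpyHostToDevice);
  cudaMemcpy(d_queries, h_queries, sizeof(h_queries), cudaMemcpyHostToDevice);
  cudaMemset(d_flags, 0, flag_words * sizeof(int));  // all flags start cleared

  MarkKernel<<<1, 32>>>(d_marked, 3, d_flags);
  CheckKernel<<<1, 32>>>(d_queries, 4, d_flags, d_hit);

  bool h_hit[4];
  cudaMemcpy(h_hit, d_hit, sizeof(h_hit), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 4; ++i) {
    printf("index %d marked: %d\n", h_queries[i], h_hit[i]);  // expect 1, 0, 1, 1
  }

  cudaFree(d_marked);
  cudaFree(d_queries);
  cudaFree(d_flags);
  cudaFree(d_hit);
  return 0;
}

Because the flag array answers "does this output index exist?", index 0 no longer needs to double as an "empty" marker, which is why GetOutIndexTable1 and MaskTable in the patch can store the plain row id i directly.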