From 8fbe97e446b2cdb4c06a200af123712fb667e238 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 21 Sep 2022 19:21:07 +0800 Subject: [PATCH] Revert "SparseConv support duplicate coordinates (#44976)" (#45202) This reverts commit e8de9dfd3b981eff47507011717ba201c5f75604. --- paddle/phi/kernels/funcs/sparse/scatter.cu.h | 9 +- paddle/phi/kernels/sparse/gpu/conv.cu.h | 147 ++++-------------- .../kernels/sparse/gpu/conv_grad_kernel.cu | 40 +---- paddle/phi/kernels/sparse/gpu/conv_kernel.cu | 25 ++- 4 files changed, 53 insertions(+), 168 deletions(-) diff --git a/paddle/phi/kernels/funcs/sparse/scatter.cu.h b/paddle/phi/kernels/funcs/sparse/scatter.cu.h index 6c293b23944..f27174d5818 100644 --- a/paddle/phi/kernels/funcs/sparse/scatter.cu.h +++ b/paddle/phi/kernels/funcs/sparse/scatter.cu.h @@ -79,7 +79,6 @@ __global__ void ScatterKernelV2(const T* input, const int* index_groups, const int non_zero_num, const int kernel_size, - const int max_voxel, const int channels, const int buffer_counts, T* out) { @@ -97,11 +96,10 @@ __global__ void ScatterKernelV2(const T* input, &sums); for (int it = 0; it < buffer_counts; it++) { int len = index_counts[indices_i + it * non_zero_num]; - const int group_offset = it * max_voxel * kernel_size * non_zero_num; + const int group_offset = it * kernel_size * non_zero_num; for (int j = 0; j < len; j++) { const int out_feature_i = - index_groups[indices_i * max_voxel * kernel_size + j + - group_offset]; + index_groups[indices_i * kernel_size + j + group_offset]; LoadT vec_in; phi::Load( input + out_feature_i * channels + channels_i * VecSize, &vec_in); @@ -123,7 +121,6 @@ void ScatterV2(const GPUContext& dev_ctx, const int* index_groups, const int non_zero_num, const int kernel_size, - const int max_voxel, const int channels, const int buffer_counts, T* output) { @@ -139,7 +136,6 @@ void ScatterV2(const GPUContext& dev_ctx, index_groups, non_zero_num, kernel_size, - max_voxel, channels, buffer_counts, output); @@ -154,7 +150,6 @@ void ScatterV2(const GPUContext& dev_ctx, index_groups, non_zero_num, kernel_size, - max_voxel, channels, buffer_counts, output); diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h index 77eea316290..161930a06fa 100644 --- a/paddle/phi/kernels/sparse/gpu/conv.cu.h +++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h @@ -66,7 +66,6 @@ __global__ void GatherKernelV2(const T* inputs, const int* index_groups, const int non_zero_num, const int kernel_size, - const int max_voxel, const int channels, const int buffer_count, T* output) { @@ -84,11 +83,10 @@ __global__ void GatherKernelV2(const T* inputs, #pragma unroll for (int it = 0; it < buffer_count; it++) { int len = index_counts[indices_i + it * non_zero_num]; - const int group_offset = it * kernel_size * max_voxel * non_zero_num; + const int group_offset = it * kernel_size * non_zero_num; #pragma unroll for (int j = 0; j < len; j++) { - int out_i = index_groups[indices_i * kernel_size * max_voxel + j + - group_offset]; + int out_i = index_groups[indices_i * kernel_size + j + group_offset]; phi::Store( in_vec, output + out_i * channels + channels_i * VecSize); } @@ -130,7 +128,6 @@ inline void GatherV2(const GPUContext& dev_ctx, const int* index_groups, const int non_zero_num, const int kernel_size, - const int max_voxel, const int channels, const int buffer_count, T* output) { @@ -146,7 +143,6 @@ inline void GatherV2(const GPUContext& dev_ctx, index_groups, non_zero_num, kernel_size, - max_voxel, channels, buffer_count, output); @@ -161,7 +157,6 @@ inline void GatherV2(const GPUContext& dev_ctx, index_groups, non_zero_num, kernel_size, - max_voxel, channels, buffer_count, output); @@ -207,7 +202,7 @@ __global__ void UniqueKernel(const IntT* in_indexs, template __global__ void GroupIndexs(const int* out_index_table, const int n, - const int offset, + const int kernel_size, IntT* out_indexs, int* out_index_counts, int* out_index_groups) { @@ -219,7 +214,7 @@ __global__ void GroupIndexs(const int* out_index_table, // kernel_size at most int j = atomicAdd(out_index_counts + real_index, 1); // nnz * kernel_size - out_index_groups[real_index * offset + j] = i; + out_index_groups[real_index * kernel_size + j] = i; } } @@ -303,36 +298,18 @@ __global__ void ProductRuleBookKernel(const T* x_indices, } } -template +template __global__ void GetOutIndexTable(const IntT* indices, const IntT non_zero_num, const Dims4D dims, - int* out_index_table, - int* out_index_table2, - int* max_voxel) { - __shared__ int cache_max; - if (threadIdx.x == 0) { - cache_max = 0; - } - __syncthreads(); - + int* out_index_table) { CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) { IntT batch = indices[i]; IntT in_z = indices[i + non_zero_num]; IntT in_y = indices[i + 2 * non_zero_num]; IntT in_x = indices[i + 3 * non_zero_num]; IntT index = PointToIndex(batch, in_x, in_y, in_z, dims); - if (save_out_index) { - out_index_table[index] = i == 0 ? -1 : i; - } - - int count = atomicAdd(out_index_table2 + index, 1); - atomicMax(&cache_max, count); - } - - __syncthreads(); - if (threadIdx.x == 0) { - atomicMax(max_voxel, cache_max + 1); + out_index_table[index] = i == 0 ? -1 : i; } } @@ -341,22 +318,10 @@ __global__ void GetOutIndexTable(int* indexs, const int non_zero_num, const Dims4D out_dims, int* out_index_table, - int* out_index_table2, - int* max_voxel, IntT* out_indices) { - __shared__ int cache_max; - if (threadIdx.x == 0) { - cache_max = 0; - } - __syncthreads(); - CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) { IntT index = static_cast(indexs[i]); out_index_table[index] = i; - - int count = atomicAdd(out_index_table2 + index, 1); - atomicMax(&cache_max, count); - IntT batch, x, y, z; phi::funcs::sparse::IndexToPoint( index, out_dims, &batch, &x, &y, &z); @@ -367,11 +332,6 @@ __global__ void GetOutIndexTable(int* indexs, out_indices[i + non_zero_num * 3] = x; indexs[i] = 0; } - - __syncthreads(); - if (threadIdx.x == 0) { - atomicMax(max_voxel, cache_max + 1); - } } template @@ -491,7 +451,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices, template __global__ void GroupIndexs(const int n, - const int offset, + const int kernel_size, const IntT* indexs, int* index_counts, int* index_groups) { @@ -500,7 +460,7 @@ __global__ void GroupIndexs(const int n, // kernel_size at most int j = atomicAdd(index_counts + index, 1); // nnz * kernel_size - index_groups[index * offset + j] = i; + index_groups[index * kernel_size + j] = i; } } @@ -508,7 +468,7 @@ __global__ void GroupIndexs(const int n, template __global__ void GroupIndexsV2(const int rulebook_len, const int non_zero_num, - const int offset, + const int kernel_size, const int half_kernel_offset, const IntT* indexs, int* index_counts, @@ -519,11 +479,11 @@ __global__ void GroupIndexsV2(const int rulebook_len, i < half_kernel_offset ? index_counts : index_counts + non_zero_num; int* groups_ptr = i < half_kernel_offset ? index_groups - : index_groups + non_zero_num * offset; + : index_groups + non_zero_num * kernel_size; // conflict kernel_size times at most int j = atomicAdd(counts_ptr + index, 1); // nnz * kernel_size - groups_ptr[index * offset + j] = i; + groups_ptr[index * kernel_size + j] = i; } } @@ -622,10 +582,6 @@ int ProductRuleBook(const Context& dev_ctx, DenseTensor out_index_table = phi::Empty(dev_ctx, {table_size}); int* out_index_table_ptr = out_index_table.data(); - DenseTensor out_index_table2 = phi::Empty(dev_ctx, {table_size + 1}); - int* out_index_table2_ptr = out_index_table2.data(); - int* h_max_voxel = h_counter + kernel_size; - if (subm) { DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta)); IntT* rulebook_ptr = tmp_rulebook.data(); @@ -636,29 +592,14 @@ int ProductRuleBook(const Context& dev_ctx, phi::backends::gpu::GpuMemsetAsync( out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); - phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr, - 0, - sizeof(int) * (table_size + 1), - dev_ctx.stream()); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - GetOutIndexTable - <<>>(out_indices.data(), - non_zero_num, - d_x_dims, - out_index_table_ptr, - out_index_table2_ptr, - out_index_table2_ptr + table_size); - phi::backends::gpu::GpuMemcpyAsync(h_max_voxel, - out_index_table2_ptr + table_size, - sizeof(int), - gpuMemcpyDeviceToHost, - dev_ctx.stream()); - dev_ctx.Wait(); + GetOutIndexTable<<>>( + out_indices.data(), non_zero_num, d_x_dims, out_index_table_ptr); size_t cache_size = kernel_size * 2 * sizeof(int) + @@ -712,22 +653,6 @@ int ProductRuleBook(const Context& dev_ctx, out_rulebook_ptr); *rulebook = out_rulebook; - unique_value->ResizeAndAllocate( - {static_cast(non_zero_num * h_max_voxel[0] * kernel_size)}); - int* unique_value_ptr = unique_value->data(); - out_index->ResizeAndAllocate({static_cast(rulebook_len)}); - int* out_index_ptr = out_index->data(); - phi::backends::gpu::GpuMemsetAsync( - out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); - GroupIndexs<<>>(rulebook_len, - kernel_size * h_max_voxel[0], - out_rulebook_ptr + rulebook_len, - out_index_ptr, - unique_value_ptr); - return rulebook_len; } else { @@ -811,35 +736,17 @@ int ProductRuleBook(const Context& dev_ctx, IntT* out_indices_ptr = out_indices.data(); - phi::backends::gpu::GpuMemsetAsync( - out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); - phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr, - 0, - sizeof(int) * (table_size + 1), - dev_ctx.stream()); - config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1); - GetOutIndexTable - <<>>(out_index_ptr, - out_nnz, - d_out_dims, - out_index_table_ptr, - out_index_table2_ptr, - out_index_table2_ptr + table_size, - out_indices_ptr); - phi::backends::gpu::GpuMemcpyAsync(h_max_voxel, - out_index_table2_ptr + table_size, - sizeof(int), - gpuMemcpyDeviceToHost, - dev_ctx.stream()); - dev_ctx.Wait(); - + GetOutIndexTable<<>>(out_index_ptr, + out_nnz, + d_out_dims, + out_index_table_ptr, + out_indices_ptr); config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); - unique_value->ResizeAndAllocate( - {static_cast(out_nnz * h_max_voxel[0] * kernel_size)}); + unique_value->ResizeAndAllocate({static_cast(out_nnz * kernel_size)}); int* unique_value_ptr = unique_value->data(); GroupIndexs<<>>(out_index_table_ptr, rulebook_len, - kernel_size * h_max_voxel[0], + kernel_size, rulebook_ptr + rulebook_len, out_index_ptr, unique_value_ptr); diff --git a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu index 5d57afab403..adfdb09968c 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu @@ -119,44 +119,10 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, } } - int max_voxel = counter_ptr[kernel_size]; - if (!subm) { - const auto& x_dims = x.dims(); - Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]); - int64_t table_size = 1; - for (int i = 0; i < x_dims.size() - 1; i++) { - table_size *= x_dims[i]; - } - DenseTensor in_index_table = phi::Empty(dev_ctx, {table_size + 1}); - int* in_index_table_ptr = in_index_table.data(); - phi::backends::gpu::GpuMemsetAsync(in_index_table_ptr, - 0, - sizeof(int) * (table_size + 1), - dev_ctx.stream()); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x.nnz(), 1); - GetOutIndexTable - <<>>(x.indices().data(), - x.nnz(), - d_x_dims, - nullptr, - in_index_table_ptr, - in_index_table_ptr + table_size); - - phi::backends::gpu::GpuMemcpyAsync(&max_voxel, - in_index_table_ptr + table_size, - sizeof(int), - gpuMemcpyDeviceToHost, - dev_ctx.stream()); - dev_ctx.Wait(); - } - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); DenseTensor unique_value = phi::Empty( - dev_ctx, {static_cast(x_grad->nnz() * max_voxel * kernel_size * 2)}); + dev_ctx, {static_cast(x_grad->nnz() * kernel_size * 2)}); DenseTensor out_index = phi::Empty(dev_ctx, {static_cast(x.nnz() * 2)}); int* out_index_ptr = out_index.data(); @@ -169,7 +135,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, 0, dev_ctx.stream()>>>(rulebook_len, x.nnz(), - kernel_size * max_voxel, + kernel_size, offsets[kernel_size / 2], rulebook_ptr, out_index_ptr, @@ -181,7 +147,6 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, unique_value_ptr, x.nnz(), kernel_size, - max_voxel, in_channels, 2, in_features_ptr); @@ -242,7 +207,6 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, unique_value.data(), x_grad->nnz(), kernel_size, - max_voxel, in_channels, 2, x_grad_values_ptr); diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu index e5727c4faab..282033e62e3 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu @@ -66,7 +66,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; DenseTensor h_counter, h_offsets; - h_counter.Resize({kernel_size + 1}); + h_counter.Resize({kernel_size}); h_offsets.Resize({kernel_size + 1}); int* h_counter_ptr = dev_ctx.template HostAlloc(&h_counter); int* h_offsets_ptr = dev_ctx.template HostAlloc(&h_offsets); @@ -74,7 +74,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, // Second algorithm: // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // 1. product rulebook - DenseTensor counter_per_kernel = phi::Empty(dev_ctx, {kernel_size + 1}); + DenseTensor counter_per_kernel = phi::Empty(dev_ctx, {kernel_size}); DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, {kernel_size}); DenseTensor out_index = phi::Empty(dev_ctx, {1}); DenseTensor unique_value = phi::Empty(dev_ctx, {1}); @@ -143,6 +143,26 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, T* out_values_ptr = out_values->data(); set_zero(dev_ctx, out_values, static_cast(0.0f)); + if (subm) { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + unique_value.ResizeAndAllocate( + {static_cast(out->nnz() * kernel_size)}); + out_index.ResizeAndAllocate({static_cast(rulebook_len)}); + int* out_index_ptr = out_index.data(); + int* unique_value_ptr = unique_value.data(); + phi::backends::gpu::GpuMemsetAsync( + out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); + GroupIndexs<<>>(rulebook_len, + kernel_size, + rulebook_ptr + rulebook_len, + out_index_ptr, + unique_value_ptr); + } + const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { if (h_counter_ptr[i] <= 0) { @@ -176,7 +196,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, unique_value.data(), out->nnz(), kernel_size, - h_counter_ptr[kernel_size], out_channels, 1, out_values_ptr); -- GitLab