diff --git a/paddle/phi/kernels/funcs/sparse/scatter.cu.h b/paddle/phi/kernels/funcs/sparse/scatter.cu.h
index f27174d58181868c0e6f427d38de9e7c894eeafe..6c293b2394443af41b190345c01241b1e93c6dfb 100644
--- a/paddle/phi/kernels/funcs/sparse/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/scatter.cu.h
@@ -79,6 +79,7 @@ __global__ void ScatterKernelV2(const T* input,
                                 const int* index_groups,
                                 const int non_zero_num,
                                 const int kernel_size,
+                                const int max_voxel,
                                 const int channels,
                                 const int buffer_counts,
                                 T* out) {
@@ -96,10 +97,11 @@ __global__ void ScatterKernelV2(const T* input,
                           &sums);
     for (int it = 0; it < buffer_counts; it++) {
       int len = index_counts[indices_i + it * non_zero_num];
-      const int group_offset = it * kernel_size * non_zero_num;
+      const int group_offset = it * max_voxel * kernel_size * non_zero_num;
       for (int j = 0; j < len; j++) {
         const int out_feature_i =
-            index_groups[indices_i * kernel_size + j + group_offset];
+            index_groups[indices_i * max_voxel * kernel_size + j +
+                         group_offset];
         LoadT vec_in;
         phi::Load<T, VecSize>(
             input + out_feature_i * channels + channels_i * VecSize, &vec_in);
@@ -121,6 +123,7 @@ void ScatterV2(const GPUContext& dev_ctx,
               const int* index_groups,
               const int non_zero_num,
               const int kernel_size,
+              const int max_voxel,
               const int channels,
               const int buffer_counts,
               T* output) {
@@ -136,6 +139,7 @@ void ScatterV2(const GPUContext& dev_ctx,
                                             index_groups,
                                             non_zero_num,
                                             kernel_size,
+                                            max_voxel,
                                             channels,
                                             buffer_counts,
                                             output);
@@ -150,6 +154,7 @@ void ScatterV2(const GPUContext& dev_ctx,
                                             index_groups,
                                             non_zero_num,
                                             kernel_size,
+                                            max_voxel,
                                             channels,
                                             buffer_counts,
                                             output);
diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 859857ed7baac5e5b06b4934cd44b03aedc13ad5..d68145e9585740aca644bcba2395e5e762f846ec 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -65,6 +65,7 @@ __global__ void GatherKernelV2(const T* inputs,
                                const int* index_groups,
                                const int non_zero_num,
                                const int kernel_size,
+                               const int max_voxel,
                                const int channels,
                                const int buffer_count,
                                T* output) {
@@ -82,10 +83,11 @@ __global__ void GatherKernelV2(const T* inputs,
 #pragma unroll
     for (int it = 0; it < buffer_count; it++) {
       int len = index_counts[indices_i + it * non_zero_num];
-      const int group_offset = it * kernel_size * non_zero_num;
+      const int group_offset = it * kernel_size * max_voxel * non_zero_num;
 #pragma unroll
       for (int j = 0; j < len; j++) {
-        int out_i = index_groups[indices_i * kernel_size + j + group_offset];
+        int out_i = index_groups[indices_i * kernel_size * max_voxel + j +
+                                 group_offset];
         phi::Store<T, VecSize>(
             in_vec, output + out_i * channels + channels_i * VecSize);
       }
     }
@@ -127,6 +129,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
                     const int* index_groups,
                     const int non_zero_num,
                     const int kernel_size,
+                    const int max_voxel,
                     const int channels,
                     const int buffer_count,
                     T* output) {
@@ -142,6 +145,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
                                             index_groups,
                                             non_zero_num,
                                             kernel_size,
+                                            max_voxel,
                                             channels,
                                             buffer_count,
                                             output);
@@ -156,6 +160,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
                                             index_groups,
                                             non_zero_num,
                                             kernel_size,
+                                            max_voxel,
                                             channels,
                                             buffer_count,
                                             output);
@@ -202,7 +207,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
 template <typename IntT>
 __global__ void GroupIndexs(const int* out_index_table,
                             const int n,
-                            const int kernel_size,
+                            const int offset,
                             IntT* out_indexs,
                             int* out_index_counts,
                             int* out_index_groups) {
@@ -214,7 +219,7 @@ __global__ void GroupIndexs(const int* out_index_table,

     // kernel_size at most
     int j = atomicAdd(out_index_counts + real_index, 1);
     // nnz * kernel_size
-    out_index_groups[real_index * kernel_size + j] = i;
+    out_index_groups[real_index * offset + j] = i;
   }
 }

@@ -298,18 +303,36 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
   }
 }

-template <typename IntT>
+template <typename IntT, bool save_out_index = true>
 __global__ void GetOutIndexTable(const IntT* indices,
                                  const IntT non_zero_num,
                                  const Dims4D dims,
-                                 int* out_index_table) {
+                                 int* out_index_table,
+                                 int* out_index_table2,
+                                 int* max_voxel) {
+  __shared__ int cache_max;
+  if (threadIdx.x == 0) {
+    cache_max = 0;
+  }
+  __syncthreads();
+
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT batch = indices[i];
     IntT in_z = indices[i + non_zero_num];
     IntT in_y = indices[i + 2 * non_zero_num];
     IntT in_x = indices[i + 3 * non_zero_num];
     IntT index = PointToIndex(batch, in_x, in_y, in_z, dims);
-    out_index_table[index] = i == 0 ? -1 : i;
+    if (save_out_index) {
+      out_index_table[index] = i == 0 ? -1 : i;
+    }
+
+    int count = atomicAdd(out_index_table2 + index, 1);
+    atomicMax(&cache_max, count);
+  }
+
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    atomicMax(max_voxel, cache_max + 1);
   }
 }

@@ -318,10 +341,22 @@ __global__ void GetOutIndexTable(int* indexs,
                                  const int non_zero_num,
                                  const Dims4D out_dims,
                                  int* out_index_table,
+                                 int* out_index_table2,
+                                 int* max_voxel,
                                  IntT* out_indices) {
+  __shared__ int cache_max;
+  if (threadIdx.x == 0) {
+    cache_max = 0;
+  }
+  __syncthreads();
+
   CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
     IntT index = static_cast<IntT>(indexs[i]);
     out_index_table[index] = i;
+
+    int count = atomicAdd(out_index_table2 + index, 1);
+    atomicMax(&cache_max, count);
+
     IntT batch, x, y, z;
     phi::funcs::sparse::IndexToPoint<Dims4D>(
         index, out_dims, &batch, &x, &y, &z);
@@ -332,6 +367,11 @@ __global__ void GetOutIndexTable(int* indexs,
     out_indices[i + non_zero_num * 3] = x;
     indexs[i] = 0;
   }
+
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    atomicMax(max_voxel, cache_max + 1);
+  }
 }

 template <typename T>
@@ -451,7 +491,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,

 template <typename IntT>
 __global__ void GroupIndexs(const int n,
-                            const int kernel_size,
+                            const int offset,
                             const IntT* indexs,
                             int* index_counts,
                             int* index_groups) {
@@ -460,7 +500,7 @@ __global__ void GroupIndexs(const int n,

     // kernel_size at most
     int j = atomicAdd(index_counts + index, 1);
     // nnz * kernel_size
-    index_groups[index * kernel_size + j] = i;
+    index_groups[index * offset + j] = i;
   }
 }

@@ -468,7 +508,7 @@ __global__ void GroupIndexs(const int n,
 template <typename IntT>
 __global__ void GroupIndexsV2(const int rulebook_len,
                               const int non_zero_num,
-                              const int kernel_size,
+                              const int offset,
                               const int half_kernel_offset,
                               const IntT* indexs,
                               int* index_counts,
@@ -479,11 +519,11 @@ __global__ void GroupIndexsV2(const int rulebook_len,
         i < half_kernel_offset ? index_counts : index_counts + non_zero_num;
     int* groups_ptr = i < half_kernel_offset
                           ? index_groups
-                          : index_groups + non_zero_num * kernel_size;
+                          : index_groups + non_zero_num * offset;
     // conflict kernel_size times at most
     int j = atomicAdd(counts_ptr + index, 1);
     // nnz * kernel_size
-    groups_ptr[index * kernel_size + j] = i;
+    groups_ptr[index * offset + j] = i;
   }
 }

@@ -582,6 +622,10 @@ int ProductRuleBook(const Context& dev_ctx,
   DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
   int* out_index_table_ptr = out_index_table.data<int>();

+  DenseTensor out_index_table2 = phi::Empty<int>(dev_ctx, {table_size + 1});
+  int* out_index_table2_ptr = out_index_table2.data<int>();
+  int* h_max_voxel = h_counter + kernel_size;
+
   if (subm) {
     DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
     IntT* rulebook_ptr = tmp_rulebook.data<IntT>();
@@ -594,14 +638,29 @@ int ProductRuleBook(const Context& dev_ctx,

     phi::backends::gpu::GpuMemsetAsync(
         out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
+    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
+                                       0,
+                                       sizeof(int) * (table_size + 1),
+                                       dev_ctx.stream());

     auto config =
         phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
-    GetOutIndexTable<<<config.block_per_grid,
-                       config.thread_per_block,
-                       0,
-                       dev_ctx.stream()>>>(
-        out_indices.data<IntT>(), non_zero_num, d_x_dims, out_index_table_ptr);
+    GetOutIndexTable<IntT>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(out_indices.data<IntT>(),
+                               non_zero_num,
+                               d_x_dims,
+                               out_index_table_ptr,
+                               out_index_table2_ptr,
+                               out_index_table2_ptr + table_size);
+    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
+                                       out_index_table2_ptr + table_size,
+                                       sizeof(int),
+                                       gpuMemcpyDeviceToHost,
+                                       dev_ctx.stream());
+    dev_ctx.Wait();

     size_t cache_size =
         kernel_size * 2 + kernel_size * config.thread_per_block.x * 2 *
@@ -655,6 +714,22 @@ int ProductRuleBook(const Context& dev_ctx,
                                            out_rulebook_ptr);
     *rulebook = out_rulebook;

+    unique_value->ResizeAndAllocate(
+        {static_cast<int>(non_zero_num * h_max_voxel[0] * kernel_size)});
+    int* unique_value_ptr = unique_value->data<int>();
+    out_index->ResizeAndAllocate({static_cast<int>(rulebook_len)});
+    int* out_index_ptr = out_index->data<int>();
+    phi::backends::gpu::GpuMemsetAsync(
+        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
+    GroupIndexs<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(rulebook_len,
+                                      kernel_size * h_max_voxel[0],
+                                      out_rulebook_ptr + rulebook_len,
+                                      out_index_ptr,
+                                      unique_value_ptr);
+
     return rulebook_len;

   } else {
@@ -729,17 +804,35 @@ int ProductRuleBook(const Context& dev_ctx,
     IntT* out_indices_ptr = out_indices.data<IntT>();

+    phi::backends::gpu::GpuMemsetAsync(
+        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
+    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
+                                       0,
+                                       sizeof(int) * (table_size + 1),
+                                       dev_ctx.stream());
+
     config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
-    GetOutIndexTable<<<config.block_per_grid,
-                       config.thread_per_block,
-                       0,
-                       dev_ctx.stream()>>>(out_index_ptr,
-                                           out_nnz,
-                                           d_out_dims,
-                                           out_index_table_ptr,
-                                           out_indices_ptr);
+    GetOutIndexTable<IntT>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(out_index_ptr,
+                               out_nnz,
+                               d_out_dims,
+                               out_index_table_ptr,
+                               out_index_table2_ptr,
+                               out_index_table2_ptr + table_size,
+                               out_indices_ptr);
+    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
+                                       out_index_table2_ptr + table_size,
+                                       sizeof(int),
+                                       gpuMemcpyDeviceToHost,
+                                       dev_ctx.stream());
+    dev_ctx.Wait();
+
     config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
-    unique_value->ResizeAndAllocate({static_cast<int>(out_nnz * kernel_size)});
+    unique_value->ResizeAndAllocate(
+        {static_cast<int>(out_nnz * h_max_voxel[0] * kernel_size)});
     int* unique_value_ptr = unique_value->data<int>();
     GroupIndexs<<<config.block_per_grid,
                   config.thread_per_block,
                   0,
                   dev_ctx.stream()>>>(out_index_table_ptr,
                                       rulebook_len,
-                                      kernel_size,
+                                      kernel_size * h_max_voxel[0],
                                       rulebook_ptr + rulebook_len,
                                       out_index_ptr,
                                       unique_value_ptr);
diff --git a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
index 8354068b7bcff763485f53aed442e7ee6489439f..9cbd75ed4ea99c4a3161b21a12339978dd010428 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
@@ -124,10 +124,44 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
     }
   }

+  int max_voxel = counter_ptr[kernel_size];
+  if (!subm) {
+    const auto& x_dims = x.dims();
+    Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
+    int64_t table_size = 1;
+    for (int i = 0; i < x_dims.size() - 1; i++) {
+      table_size *= x_dims[i];
+    }
+    DenseTensor in_index_table = phi::Empty<int>(dev_ctx, {table_size + 1});
+    int* in_index_table_ptr = in_index_table.data<int>();
+    phi::backends::gpu::GpuMemsetAsync(in_index_table_ptr,
+                                       0,
+                                       sizeof(int) * (table_size + 1),
+                                       dev_ctx.stream());
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x.nnz(), 1);
+    GetOutIndexTable<IntT, false>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(x.non_zero_indices().data<IntT>(),
+                               x.nnz(),
+                               d_x_dims,
+                               nullptr,
+                               in_index_table_ptr,
+                               in_index_table_ptr + table_size);
+
+    phi::backends::gpu::GpuMemcpyAsync(&max_voxel,
+                                       in_index_table_ptr + table_size,
+                                       sizeof(int),
+                                       gpuMemcpyDeviceToHost,
+                                       dev_ctx.stream());
+    dev_ctx.Wait();
+  }
+
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
   DenseTensor unique_value = phi::Empty<int>(
-      dev_ctx, {static_cast<int>(x_grad->nnz() * kernel_size * 2)});
+      dev_ctx, {static_cast<int>(x_grad->nnz() * max_voxel * kernel_size * 2)});
   DenseTensor out_index =
       phi::Empty<int>(dev_ctx, {static_cast<int>(x.nnz() * 2)});
   int* out_index_ptr = out_index.data<int>();
@@ -140,7 +174,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                           0,
                           dev_ctx.stream()>>>(rulebook_len,
                                               x.nnz(),
-                                              kernel_size,
+                                              kernel_size * max_voxel,
                                               offsets[kernel_size / 2],
                                               rulebook_ptr,
                                               out_index_ptr,
                                               unique_value_ptr);
@@ -152,6 +186,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                     unique_value_ptr,
                     x.nnz(),
                     kernel_size,
+                    max_voxel,
                     in_channels,
                     2,
                     in_features_ptr);
@@ -212,6 +247,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                  unique_value.data<int>(),
                  x_grad->nnz(),
                  kernel_size,
+                 max_voxel,
                  in_channels,
                  2,
                  x_grad_values_ptr);
diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
index d313bf3c60f637ee08afb6a5bf1a18cbf4872d76..1a2b3134657e441f97ce086692c102768848ce67 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -66,7 +66,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
   DenseTensor h_counter, h_offsets;
-  h_counter.Resize({kernel_size});
+  h_counter.Resize({kernel_size + 1});
   h_offsets.Resize({kernel_size + 1});
   int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
   int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
@@ -74,7 +74,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   // Second algorithm:
   // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
   // 1. product rulebook
-  DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
+  DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size + 1});
   DenseTensor offsets_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
   DenseTensor out_index = phi::Empty<int>(dev_ctx, {1});
   DenseTensor unique_value = phi::Empty<int>(dev_ctx, {1});
@@ -143,26 +143,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   T* out_values_ptr = out_values->data<T>();
   set_zero(dev_ctx, out_values, static_cast<T>(0.0f));

-  if (subm) {
-    auto config =
-        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
-    unique_value.ResizeAndAllocate(
-        {static_cast<int>(out->nnz() * kernel_size)});
-    out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
-    int* out_index_ptr = out_index.data<int>();
-    int* unique_value_ptr = unique_value.data<int>();
-    phi::backends::gpu::GpuMemsetAsync(
-        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
-    GroupIndexs<<<config.block_per_grid,
-                  config.thread_per_block,
-                  0,
-                  dev_ctx.stream()>>>(rulebook_len,
-                                      kernel_size,
-                                      rulebook_ptr + rulebook_len,
-                                      out_index_ptr,
-                                      unique_value_ptr);
-  }
-
   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
     if (h_counter_ptr[i] <= 0) {
@@ -196,6 +176,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
                  unique_value.data<int>(),
                  out->nnz(),
                  kernel_size,
+                 h_counter_ptr[kernel_size],
                  out_channels,
                  1,
                  out_values_ptr);
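For reference, the occupancy counting that the patched GetOutIndexTable performs can be reduced to the minimal, self-contained CUDA sketch below. It is not part of the patch; names such as CountVoxelOccupancy, flat_voxel_index, and num_points are illustrative assumptions. Each thread bumps a per-voxel counter with atomicAdd, the block keeps its local maximum in shared memory, and one thread per block publishes it with a global atomicMax. The host then reads back max_voxel, which the kernels above use to size out_index_groups / unique_value as roughly nnz * max_voxel * kernel_size instead of nnz * kernel_size.

// Hypothetical standalone sketch of the max_voxel reduction (not the patch itself).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void CountVoxelOccupancy(const int* flat_voxel_index,
                                    const int num_points,
                                    int* voxel_counts,  // size: num_voxels, zero-initialized
                                    int* max_voxel) {   // single int, zero-initialized
  __shared__ int cache_max;  // block-local maximum occupancy seen so far
  if (threadIdx.x == 0) cache_max = 0;
  __syncthreads();

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_points;
       i += gridDim.x * blockDim.x) {
    // "count" is the previous value, so count + 1 points now map to this voxel.
    int count = atomicAdd(voxel_counts + flat_voxel_index[i], 1);
    atomicMax(&cache_max, count);
  }

  __syncthreads();
  if (threadIdx.x == 0) {
    // Publish the block maximum; +1 converts "previous count" into occupancy.
    atomicMax(max_voxel, cache_max + 1);
  }
}

int main() {
  const int num_points = 8, num_voxels = 4;
  int h_index[num_points] = {0, 1, 1, 2, 2, 2, 3, 3};  // voxel 2 holds 3 points

  int *d_index, *d_counts, *d_max;
  cudaMalloc((void**)&d_index, num_points * sizeof(int));
  cudaMalloc((void**)&d_counts, num_voxels * sizeof(int));
  cudaMalloc((void**)&d_max, sizeof(int));
  cudaMemcpy(d_index, h_index, num_points * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemset(d_counts, 0, num_voxels * sizeof(int));
  cudaMemset(d_max, 0, sizeof(int));

  CountVoxelOccupancy<<<2, 4>>>(d_index, num_points, d_counts, d_max);

  int h_max = 0;
  cudaMemcpy(&h_max, d_max, sizeof(int), cudaMemcpyDeviceToHost);
  printf("max_voxel = %d\n", h_max);  // expected: 3

  cudaFree(d_index);
  cudaFree(d_counts);
  cudaFree(d_max);
  return 0;
}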