diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
index 57bc85069a6eb81ff2d578373f4f4c73a5351504..1cd3086d5f74ca66ce145faff4e1a31e310521cf 100644
--- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
@@ -206,7 +206,11 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
   if (batchs > 1) {
     for (int i = 0; i < non_zero_num; i++) {
       if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
-        offsets[batchs_ptr[i]] = i + 1;
+        const int start = batchs_ptr[i];
+        const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
+        for (int j = start; j < end; j++) {
+          offsets[j] = i + 1;
+        }
       }
     }
   } else {
@@ -214,7 +218,6 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
   }
 
   for (int b = 0; b < batchs; b++) {
-    if (offsets[b] == 0) continue;
     int batch_start = 0;
     int batch_non_zero_num = offsets[b];
     if (b > 0) {
@@ -233,6 +236,9 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
     for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) {
       csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
     }
+    if (batch_non_zero_num == 0) {
+      memset(csr_crows_data + b * (rows + 1), 0, sizeof(IntT) * (rows + 1));
+    }
   }
 
   memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num);
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
index 94022d6392eea662fd49efddee8fa4f1d6b5730d..1ed4ebd23db87ebf353844e4b8daf029de41dba2 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/sparse/common_shape.h"
 #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
 
@@ -283,19 +284,24 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
 
 template <typename IntT>
 __global__ void GetBatchsOffset(const IntT* batchs_ptr,
+                                const int batchs,
                                 const int non_zero_num,
-                                IntT* batchs_offset) {
+                                int* batchs_offset) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) {
     if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
-      batchs_offset[batchs_ptr[i]] = i + 1;
+      const int start = batchs_ptr[i];
+      const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
+      for (int j = start; j < end; j++) {
+        batchs_offset[j] = i + 1;
+      }
     }
   }
 }
 
 template <typename IntT>
 __global__ void ConvertCooRowsToCsrCrows(
-    const IntT* batchs_offset,  // can be null if batchs = 1
+    const int* batchs_offset,  // can be null if batchs = 1
     const IntT* coo_rows_data,
     IntT* csr_crows_data,
     const int rows,
@@ -303,12 +309,12 @@ __global__ void ConvertCooRowsToCsrCrows(
   const int b = blockIdx.y;
   int batch_non_zero_num =
       batchs_offset == nullptr ? non_zero_num : batchs_offset[b];
-  if (batch_non_zero_num == 0) return;
   IntT batch_start = 0;
   if (b > 0) {
     batch_start = batchs_offset[b - 1];
     batch_non_zero_num -= batch_start;
   }
+
   const IntT* coo_rows_ptr = coo_rows_data + batch_start;
   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
   for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) {
@@ -328,6 +334,11 @@ __global__ void ConvertCooRowsToCsrCrows(
       }
     }
   }
+  if (batch_non_zero_num == 0) {
+    for (int i = tid; i < rows + 1; i += gridDim.x * blockDim.x) {
+      csr_crows_data[b * (rows + 1) + i] = 0;
+    }
+  }
 }
 
 template <typename T, typename IntT>
@@ -365,13 +376,19 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
   if (batchs > 1) {
-    phi::DenseTensor batchs_offset = phi::Empty<IntT>(dev_ctx, {batchs});
-    IntT* batchs_offset_ptr = batchs_offset.data<IntT>();
-    GetBatchsOffset<IntT>
-        <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(batchs_ptr, non_zero_num, batchs_offset_ptr);
+    auto config =
+        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
+    phi::DenseTensor batchs_offset = phi::Empty<int>(dev_ctx, {batchs});
+    int* batchs_offset_ptr = batchs_offset.data<int>();
+    phi::funcs::SetConstant<GPUContext, int> set_zero;
+    // set zero if the nnz=0 of batchs[0]
+    set_zero(dev_ctx, &batchs_offset, static_cast<int>(0));
+    GetBatchsOffset<IntT><<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
+        batchs_ptr, batchs, non_zero_num, batchs_offset_ptr);
+    config.block_per_grid.y = batchs;
     ConvertCooRowsToCsrCrows<IntT><<<
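
Note (reviewer sketch, not part of the patch): the heart of both the CPU and GPU changes is the forward-fill of the per-batch offsets. Before the fix, only a batch that owns at least one non-zero got an entry in offsets / batchs_offset, so a batch with nnz = 0 kept a stale zero and its crows segment was skipped (via the removed continue / return) and left uninitialized. The standalone C++ sketch below replays the fixed loop on sample data; batchs_ptr and offsets mirror the kernel's variables, while the sample values and the main() scaffolding are illustrative only.

#include <cstdio>
#include <vector>

int main() {
  // batchs_ptr holds the (sorted) batch index of every non-zero element.
  // Batches 1 and 3 are empty: 4 batches, 5 non-zeros in batches 0 and 2.
  const std::vector<int> batchs_ptr = {0, 0, 2, 2, 2};
  const int batchs = 4;
  const int non_zero_num = static_cast<int>(batchs_ptr.size());

  std::vector<int> offsets(batchs, 0);
  for (int i = 0; i < non_zero_num; i++) {
    if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) {
      // Forward-fill: every batch from the current one up to (but not
      // including) the next batch that owns a non-zero gets the running
      // count; `end = batchs` on the last element covers trailing empties.
      const int start = batchs_ptr[i];
      const int end = i == non_zero_num - 1 ? batchs : batchs_ptr[i + 1];
      for (int j = start; j < end; j++) {
        offsets[j] = i + 1;
      }
    }
  }

  // Prints "2 2 5 5": empty batches inherit the cumulative count, so
  // offsets[b] - offsets[b - 1] is 0 for them rather than garbage, and
  // their crows segment can be zero-filled the way the patch does.
  for (int b = 0; b < batchs; b++) printf("%d ", offsets[b]);
  printf("\n");
  return 0;
}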