From 5752643b6467948a5e06b41e9688d7358b5c5a25 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 13 Jun 2022 10:22:34 +0800 Subject: [PATCH] sparse convertion kernel support secondary dispatch (#43345) * use GpuMemcpy and GpuMemset * sparse convert kernel support double dispatch by indices dtype * cudaMemcpyKind->gpuMemcpyKind --- paddle/phi/backends/gpu/gpu_types.h | 10 + .../kernels/sparse/cpu/sparse_utils_kernel.cc | 120 +++--- .../kernels/sparse/gpu/sparse_utils_kernel.cu | 362 ++++++++---------- .../tests/unittests/test_sparse_utils_op.py | 52 +-- 4 files changed, 270 insertions(+), 274 deletions(-) diff --git a/paddle/phi/backends/gpu/gpu_types.h b/paddle/phi/backends/gpu/gpu_types.h index 05f6c5a545c..77f403795b6 100644 --- a/paddle/phi/backends/gpu/gpu_types.h +++ b/paddle/phi/backends/gpu/gpu_types.h @@ -67,6 +67,16 @@ DECLARE_CONSTANT_FOR_GPU(gpuErrorOutOfMemory, DECLARE_CONSTANT_FOR_GPU(gpuErrorNotReady, cudaErrorNotReady, hipErrorNotReady); DECLARE_CONSTANT_FOR_GPU(gpuSuccess, cudaSuccess, hipSuccess); +DECLARE_CONSTANT_FOR_GPU(gpuMemcpyHostToDevice, + cudaMemcpyKind::cudaMemcpyHostToDevice, + hipMemcpyKind::hipMemcpyHostToDevice); +DECLARE_CONSTANT_FOR_GPU(gpuMemcpyDeviceToHost, + cudaMemcpyKind::cudaMemcpyDeviceToHost, + hipMemcpyKind::hipMemcpyDeviceToHost); +DECLARE_CONSTANT_FOR_GPU(gpuMemcpyDeviceToDevice, + cudaMemcpyKind::cudaMemcpyDeviceToDevice, + hipMemcpyKind::hipMemcpyDeviceToDevice); + #undef DECLARE_CONSTANT_FOR_GPU } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index 28b1b3368ed..57bc85069a6 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/sparse/common_shape.h" namespace phi { @@ -68,20 +69,23 @@ void DenseToSparseCooKernel(const Context& dev_ctx, SparseCooTensor* out) { const T* x_data = x.data(); const auto& x_dims = x.dims(); + PADDLE_ENFORCE_LE(sparse_dim, + x_dims.size(), + phi::errors::InvalidArgument( + "sparse_dim must be less than the size of x.dims()")); + PADDLE_ENFORCE_GT( + sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0")); int64_t non_zero_num = GetNonZeroNum(x, sparse_dim); - const auto place = dev_ctx.GetPlace(); const auto values_dims = phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num); - DenseTensorMeta indices_meta(DataType::INT64, - {sparse_dim, static_cast(non_zero_num)}, - DataLayout::NCHW); DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout); - phi::DenseTensor indices = phi::Empty(dev_ctx, std::move(indices_meta)); + phi::DenseTensor indices = + phi::Empty(dev_ctx, {sparse_dim, non_zero_num}); phi::DenseTensor values = phi::Empty(dev_ctx, std::move(values_meta)); - int64_t* indices_data = indices.mutable_data(place); - T* values_data = values.mutable_data(place); + int64_t* indices_data = indices.data(); + T* values_data = values.data(); auto dims_2d = flatten_to_2d(x_dims, sparse_dim); const int rows = dims_2d[0]; @@ -102,36 +106,32 @@ void DenseToSparseCooKernel(const Context& dev_ctx, out->SetMember(indices, values, x_dims, true); } -template -void SparseCsrToCooKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out) { +template +void SparseCsrToCooCPUKernel(const CPUContext& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { const DDim& x_dims = x.dims(); const int64_t non_zero_num = x.non_zero_cols().numel(); const auto& csr_crows = x.non_zero_crows(); const auto& csr_cols = x.non_zero_cols(); const auto& csr_values = x.non_zero_elements(); - const int64_t* csr_crows_data = csr_crows.data(); - const int64_t* csr_cols_data = csr_cols.data(); + const IntT* csr_crows_data = csr_crows.data(); + const IntT* csr_cols_data = csr_cols.data(); const T* csr_values_data = csr_values.data(); int64_t sparse_dim = 2; if (x_dims.size() == 3) { sparse_dim = 3; } - const auto place = dev_ctx.GetPlace(); - DenseTensorMeta indices_meta( - DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW); - DenseTensorMeta values_meta( - x.dtype(), {non_zero_num}, x.non_zero_elements().layout()); - phi::DenseTensor indices = phi::Empty(dev_ctx, std::move(indices_meta)); - phi::DenseTensor values = phi::Empty(dev_ctx, std::move(values_meta)); - int64_t* coo_indices = indices.mutable_data(place); - int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices; - int64_t* coo_rows_data = + phi::DenseTensor indices = + phi::Empty(dev_ctx, {sparse_dim, non_zero_num}); + phi::DenseTensor values = phi::Empty(dev_ctx, {non_zero_num}); + IntT* coo_indices = indices.data(); + IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices; + IntT* coo_rows_data = x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num; - int64_t* coo_cols_data = coo_rows_data + non_zero_num; - T* coo_values_data = values.mutable_data(place); + IntT* coo_cols_data = coo_rows_data + non_zero_num; + T* coo_values_data = values.data(); int batch = x_dims.size() == 2 ? 1 : x_dims[0]; int rows = x_dims.size() == 2 ? 
x_dims[0] : x_dims[1]; @@ -139,7 +139,7 @@ void SparseCsrToCooKernel(const Context& dev_ctx, int index = 0; for (int b = 0; b < batch; b++) { for (int i = 0; i < rows; i++) { - for (int j = csr_crows_data[b * (rows + 1) + i]; + for (IntT j = csr_crows_data[b * (rows + 1) + i]; j < csr_crows_data[b * (rows + 1) + i + 1]; j++) { coo_rows_data[index] = i; @@ -151,15 +151,25 @@ void SparseCsrToCooKernel(const Context& dev_ctx, } } - memcpy(coo_cols_data, csr_cols_data, sizeof(int64_t) * non_zero_num); + memcpy(coo_cols_data, csr_cols_data, sizeof(IntT) * non_zero_num); memcpy(coo_values_data, csr_values_data, sizeof(T) * non_zero_num); out->SetMember(indices, values, x_dims, true); } template -void SparseCooToCsrKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out) { +void SparseCsrToCooKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_crows().dtype(), "SparseCsrToCooCPUKernel", ([&] { + SparseCsrToCooCPUKernel(dev_ctx, x, out); + })); +} + +template +void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { const auto& x_dims = x.dims(); bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, @@ -174,11 +184,11 @@ void SparseCooToCsrKernel(const Context& dev_ctx, phi::DenseTensor non_zero_crows; non_zero_crows.Resize({batchs * (rows + 1)}); - int64_t* csr_crows_data = dev_ctx.template Alloc(&non_zero_crows); + IntT* csr_crows_data = dev_ctx.template Alloc(&non_zero_crows); phi::DenseTensor non_zero_cols; non_zero_cols.Resize({non_zero_num}); - int64_t* csr_cols_data = dev_ctx.template Alloc(&non_zero_cols); + IntT* csr_cols_data = dev_ctx.template Alloc(&non_zero_cols); phi::DenseTensor non_zero_elements; non_zero_elements.Resize({non_zero_num}); @@ -186,16 +196,12 @@ void SparseCooToCsrKernel(const Context& dev_ctx, const auto& coo_indices = x.non_zero_indices(); const auto& coo_values = x.non_zero_elements(); - const int64_t* batchs_ptr = coo_indices.data(); - const int64_t* coo_rows_data = + const IntT* batchs_ptr = coo_indices.data(); + const IntT* coo_rows_data = batchs == 1 ? 
batchs_ptr : batchs_ptr + non_zero_num; - const int64_t* coo_cols_data = coo_rows_data + non_zero_num; + const IntT* coo_cols_data = coo_rows_data + non_zero_num; const T* coo_values_data = coo_values.data(); - if (!x.coalesced()) { - // TODO(zhangkahuo): call coalesced() to distinct and sort the indices - } - std::vector offsets(batchs, 0); if (batchs > 1) { for (int i = 0; i < non_zero_num; i++) { @@ -220,25 +226,34 @@ void SparseCooToCsrKernel(const Context& dev_ctx, csr_crows_data[b * (rows + 1) + i] = 0; } for (int64_t i = 1; i < batch_non_zero_num; i++) { - for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) { + for (IntT j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) { csr_crows_data[b * (rows + 1) + j + 1] = i; } } - for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; - i++) { + for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) { csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num; } } - memcpy(csr_cols_data, coo_cols_data, sizeof(int64_t) * non_zero_num); + memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num); memcpy(csr_values_data, coo_values_data, sizeof(T) * non_zero_num); out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims); } template -void SparseCooToDenseKernel(const Context& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out) { +void SparseCooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "SparseCooToCsrCPUKernel", ([&] { + SparseCooToCsrCPUKernel(dev_ctx, x, out); + })); +} + +template +void SparseCooToDenseCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { const auto non_zero_num = x.nnz(); const auto dense_dims = x.dims(); const auto indices = x.non_zero_indices(); @@ -270,8 +285,7 @@ void SparseCooToDenseKernel(const Context& dev_ctx, for (auto i = 0; i < non_zero_num; i++) { int64_t index = 0; for (int j = 0; j < sparse_dim; j++) { - index += - indices.data()[j * non_zero_num + i] * sparse_offsets[j]; + index += indices.data()[j * non_zero_num + i] * sparse_offsets[j]; } for (int j = 0; j < base_offset; j++) { @@ -280,6 +294,16 @@ void SparseCooToDenseKernel(const Context& dev_ctx, } } +template +void SparseCooToDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "SparseCooToDenseCPUKernel", ([&] { + SparseCooToDenseCPUKernel(dev_ctx, x, out); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index 38553d1fe1d..94022d6392e 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -15,11 +15,12 @@ limitations under the License. 
*/ #include #include -#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/sparse/common_shape.h" #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" @@ -96,39 +97,33 @@ void DenseToSparseCooKernel(const Context& dev_ctx, SparseCooTensor* out) { const T* x_data = x.data(); const auto& x_dims = x.dims(); + PADDLE_ENFORCE_LE(sparse_dim, + x_dims.size(), + phi::errors::InvalidArgument( + "sparse_dim must be less than the size of x.dims()")); + PADDLE_ENFORCE_GT( + sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0")); auto dims_2d = flatten_to_2d(x_dims, sparse_dim); const int rows = dims_2d[0]; const int cols = dims_2d[1]; - auto nums_meta = - phi::DenseTensorMeta(DataType::INT32, {1}, phi::DataLayout::NCHW); - DenseTensor nums = phi::Empty(dev_ctx, std::move(nums_meta)); - auto x_dims_meta = phi::DenseTensorMeta(DataType::INT64, - {static_cast(x_dims.size())}, - phi::DataLayout::NCHW); - DenseTensor d_x_dims = phi::Empty(dev_ctx, std::move(x_dims_meta)); - - const auto place = dev_ctx.GetPlace(); + DenseTensor nums = phi::Empty(dev_ctx, {1}); + DenseTensor d_x_dims = phi::Empty(dev_ctx, {x_dims.size()}); // 1. get numbers of non zero elements, and get the index of non zero elements - int* nums_ptr = nums.mutable_data(place); -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream())); -#endif + int* nums_ptr = nums.data(); + phi::backends::gpu::GpuMemsetAsync( + nums_ptr, 0, sizeof(int), dev_ctx.stream()); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1); - auto temp_indexs_meta = - phi::DenseTensorMeta(DataType::INT32, {rows}, phi::DataLayout::NCHW); - DenseTensor temp_indexs = phi::Empty(dev_ctx, std::move(temp_indexs_meta)); - int* temp_indexs_ptr = temp_indexs.mutable_data(place); + DenseTensor temp_indexs = phi::Empty(dev_ctx, {rows}); + int* temp_indexs_ptr = temp_indexs.data(); + GetNonZeroNums<<>>( x_data, rows, cols, nums_ptr, temp_indexs_ptr); + #ifdef PADDLE_WITH_HIP thrust::remove(thrust::hip::par.on(dev_ctx.stream()), #else @@ -140,35 +135,16 @@ void DenseToSparseCooKernel(const Context& dev_ctx, // 2. 
copy non_zero_num to host, copy x_dims to device int non_zero_num = 0; -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(&non_zero_num, - nums_ptr, - sizeof(int), - hipMemcpyDeviceToHost, - dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&non_zero_num, - nums_ptr, - sizeof(int), - cudaMemcpyDeviceToHost, - dev_ctx.stream())); -#endif - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemcpyAsync(d_x_dims.mutable_data(place), - x_dims.Get(), - x_dims.size() * sizeof(x_dims[0]), - hipMemcpyHostToDevice, - dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpyAsync(d_x_dims.mutable_data(place), - x_dims.Get(), - x_dims.size() * sizeof(x_dims[0]), - cudaMemcpyHostToDevice, - dev_ctx.stream())); -#endif + phi::backends::gpu::GpuMemcpyAsync(&non_zero_num, + nums_ptr, + sizeof(int), + gpuMemcpyDeviceToHost, + dev_ctx.stream()); + phi::backends::gpu::GpuMemcpyAsync(d_x_dims.data(), + x_dims.Get(), + x_dims.size() * sizeof(x_dims[0]), + gpuMemcpyHostToDevice, + dev_ctx.stream()); dev_ctx.Wait(); // wait the copy @@ -197,20 +173,22 @@ void DenseToSparseCooKernel(const Context& dev_ctx, out->SetMember(indices, values, x_dims, true); } -__global__ void GetBatchSizes(const int64_t* crows, +template +__global__ void GetBatchSizes(const IntT* crows, const int rows, const int batchs, - int* batch_sizes) { + IntT* batch_sizes) { const int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid < batchs) { batch_sizes[tid] = crows[tid * (rows + 1) + rows]; } } -__global__ void ConvertCsrCrowsToCooRows(const int64_t* crows_ptr, - const int* crows_offsets, - int64_t* rows_ptr, - int64_t* batch_ptr, +template +__global__ void ConvertCsrCrowsToCooRows(const IntT* crows_ptr, + const IntT* crows_offsets, + IntT* rows_ptr, + IntT* batch_ptr, const int rows) { const int b = blockIdx.y; const int64_t offset = crows_offsets ? crows_offsets[b] : 0; @@ -227,17 +205,17 @@ __global__ void ConvertCsrCrowsToCooRows(const int64_t* crows_ptr, } } -template -void SparseCsrToCooKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - SparseCooTensor* out) { +template +void SparseCsrToCooGPUKernel(const GPUContext& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { const DDim& x_dims = x.dims(); const int64_t non_zero_num = x.non_zero_cols().numel(); const auto& csr_crows = x.non_zero_crows(); const auto& csr_cols = x.non_zero_cols(); const auto& csr_values = x.non_zero_elements(); - const int64_t* csr_crows_data = csr_crows.data(); - const int64_t* csr_cols_data = csr_cols.data(); + const IntT* csr_crows_data = csr_crows.data(); + const IntT* csr_cols_data = csr_cols.data(); const T* csr_values_data = csr_values.data(); int64_t sparse_dim = 2; @@ -247,26 +225,20 @@ void SparseCsrToCooKernel(const Context& dev_ctx, int batchs = x_dims.size() == 2 ? 1 : x_dims[0]; int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1]; - const auto place = dev_ctx.GetPlace(); - DenseTensorMeta indices_meta( - DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW); - DenseTensorMeta values_meta( - x.dtype(), {non_zero_num}, x.non_zero_elements().layout()); - DenseTensorMeta offsets_meta(DataType::INT32, {batchs}, DataLayout::NCHW); - DenseTensor indices = phi::Empty(dev_ctx, std::move(indices_meta)); - DenseTensor values = phi::Empty(dev_ctx, std::move(values_meta)); - DenseTensor offsets = phi::Empty(dev_ctx, std::move(offsets_meta)); - int64_t* coo_indices = indices.mutable_data(place); - int64_t* batch_ptr = x_dims.size() == 2 ? 
nullptr : coo_indices; - int64_t* coo_rows_data = + DenseTensor indices = phi::Empty(dev_ctx, {sparse_dim, non_zero_num}); + DenseTensor values = phi::EmptyLike(dev_ctx, csr_values); + DenseTensor offsets = phi::Empty(dev_ctx, {batchs}); + IntT* coo_indices = indices.data(); + IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices; + IntT* coo_rows_data = x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num; - int64_t* coo_cols_data = coo_rows_data + non_zero_num; - int* offsets_ptr = batchs == 1 ? nullptr : offsets.mutable_data(place); - T* coo_values_data = values.mutable_data(place); + IntT* coo_cols_data = coo_rows_data + non_zero_num; + IntT* offsets_ptr = batchs == 1 ? nullptr : offsets.data(); + T* coo_values_data = values.data(); if (batchs > 1) { auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1); - GetBatchSizes<<>>( + GetBatchSizes<<>>( csr_crows_data, rows, batchs, offsets_ptr); #ifdef PADDLE_WITH_HIP @@ -281,40 +253,38 @@ void SparseCsrToCooKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1); config.block_per_grid.y = batchs; - ConvertCsrCrowsToCooRows<<>>( - csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows); - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(coo_cols_data, - csr_cols_data, - sizeof(int64_t) * non_zero_num, - hipMemcpyDeviceToDevice, - dev_ctx.stream())); - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(coo_values_data, - csr_values_data, - sizeof(T) * non_zero_num, - hipMemcpyDeviceToDevice, - dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(coo_cols_data, - csr_cols_data, - sizeof(int64_t) * non_zero_num, - cudaMemcpyDeviceToDevice, - dev_ctx.stream())); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(coo_values_data, - csr_values_data, - sizeof(T) * non_zero_num, - cudaMemcpyDeviceToDevice, - dev_ctx.stream())); -#endif + ConvertCsrCrowsToCooRows + <<>>( + csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows); + + phi::backends::gpu::GpuMemcpyAsync(coo_cols_data, + csr_cols_data, + sizeof(IntT) * non_zero_num, + gpuMemcpyDeviceToDevice, + dev_ctx.stream()); + phi::backends::gpu::GpuMemcpyAsync(coo_values_data, + csr_values_data, + sizeof(T) * non_zero_num, + gpuMemcpyDeviceToDevice, + dev_ctx.stream()); out->SetMember(indices, values, x_dims, true); } -__global__ void GetBatchsOffset(const int64_t* batchs_ptr, +template +void SparseCsrToCooKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + SparseCooTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_crows().dtype(), "SparseCsrToCooGPUKernel", ([&] { + SparseCsrToCooGPUKernel(dev_ctx, x, out); + })); +} + +template +__global__ void GetBatchsOffset(const IntT* batchs_ptr, const int non_zero_num, - int64_t* batchs_offset) { + IntT* batchs_offset) { int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { if (i == non_zero_num - 1 || batchs_ptr[i] != batchs_ptr[i + 1]) { @@ -323,35 +293,36 @@ __global__ void GetBatchsOffset(const int64_t* batchs_ptr, } } +template __global__ void ConvertCooRowsToCsrCrows( - const int64_t* batchs_offset, // can be null if batchs = 1 - const int64_t* coo_rows_data, - int64_t* csr_crows_data, + const IntT* batchs_offset, // can be null if batchs = 1 + const IntT* coo_rows_data, + IntT* csr_crows_data, const int rows, const int64_t non_zero_num) { const int b = blockIdx.y; int batch_non_zero_num = batchs_offset == nullptr ? 
non_zero_num : batchs_offset[b]; if (batch_non_zero_num == 0) return; - int batch_start = 0; + IntT batch_start = 0; if (b > 0) { batch_start = batchs_offset[b - 1]; batch_non_zero_num -= batch_start; } - auto* coo_rows_ptr = coo_rows_data + batch_start; + const IntT* coo_rows_ptr = coo_rows_data + batch_start; const int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < batch_non_zero_num; i += gridDim.x * blockDim.x) { if (i == 0) { - for (int j = 0; j <= coo_rows_ptr[0]; j++) { + for (IntT j = 0; j <= coo_rows_ptr[0]; j++) { csr_crows_data[b * (rows + 1) + j] = 0; } } else { - for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) { + for (IntT j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) { csr_crows_data[b * (rows + 1) + j + 1] = i; } } if (i == batch_non_zero_num - 1) { - for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; + for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) { csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num; } @@ -359,10 +330,10 @@ __global__ void ConvertCooRowsToCsrCrows( } } -template -void SparseCooToCsrKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCsrTensor* out) { +template +void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { const auto& x_dims = x.dims(); bool valid = x_dims.size() == 2 || x_dims.size() == 3; PADDLE_ENFORCE_EQ(valid, @@ -376,78 +347,71 @@ void SparseCooToCsrKernel(const Context& dev_ctx, int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1]; phi::DenseTensor non_zero_crows = - phi::Empty(dev_ctx, {batchs * (rows + 1)}); - phi::DenseTensor non_zero_cols = phi::Empty(dev_ctx, {non_zero_num}); - phi::DenseTensor non_zero_elements = phi::Empty(dev_ctx, {non_zero_num}); - int64_t* csr_crows_data = non_zero_crows.data(); - int64_t* csr_cols_data = non_zero_cols.data(); + phi::Empty(dev_ctx, {batchs * (rows + 1)}); + phi::DenseTensor non_zero_cols = phi::Empty(dev_ctx, {non_zero_num}); + phi::DenseTensor non_zero_elements = + phi::EmptyLike(dev_ctx, x.non_zero_elements()); + IntT* csr_crows_data = non_zero_crows.data(); + IntT* csr_cols_data = non_zero_cols.data(); T* csr_values_data = non_zero_elements.data(); const auto& coo_indices = x.non_zero_indices(); const auto& coo_values = x.non_zero_elements(); - const int64_t* batchs_ptr = coo_indices.data(); - const int64_t* coo_rows_data = + const IntT* batchs_ptr = coo_indices.data(); + const IntT* coo_rows_data = batchs == 1 ? 
batchs_ptr : batchs_ptr + non_zero_num; - const int64_t* coo_cols_data = coo_rows_data + non_zero_num; + const IntT* coo_cols_data = coo_rows_data + non_zero_num; const T* coo_values_data = coo_values.data(); - if (!x.coalesced()) { - // TODO(zhangkahuo): call coalesced() to distinct and sort the indices - } - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1); if (batchs > 1) { - DenseTensorMeta batchs_meta(DataType::INT64, {batchs}, DataLayout::NCHW); - phi::DenseTensor batchs_offset = phi::Empty(dev_ctx, {batchs}); - int64_t* batchs_offset_ptr = batchs_offset.data(); - GetBatchsOffset<<>>( - batchs_ptr, non_zero_num, batchs_offset_ptr); + phi::DenseTensor batchs_offset = phi::Empty(dev_ctx, {batchs}); + IntT* batchs_offset_ptr = batchs_offset.data(); + GetBatchsOffset + <<>>(batchs_ptr, non_zero_num, batchs_offset_ptr); config.block_per_grid.y = batchs; - ConvertCooRowsToCsrCrows<<>>( + ConvertCooRowsToCsrCrows<<>>( batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num); } else { - ConvertCooRowsToCsrCrows<<>>( + ConvertCooRowsToCsrCrows<<>>( nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num); } -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(csr_cols_data, - coo_cols_data, - sizeof(int64_t) * non_zero_num, - hipMemcpyDeviceToDevice, - dev_ctx.stream())); - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync(csr_values_data, - coo_values_data, - sizeof(T) * non_zero_num, - hipMemcpyDeviceToDevice, - dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(csr_cols_data, - coo_cols_data, - sizeof(int64_t) * non_zero_num, - cudaMemcpyDeviceToDevice, - dev_ctx.stream())); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(csr_values_data, - coo_values_data, - sizeof(T) * non_zero_num, - cudaMemcpyDeviceToDevice, - dev_ctx.stream())); -#endif + phi::backends::gpu::GpuMemcpyAsync(csr_cols_data, + coo_cols_data, + sizeof(IntT) * non_zero_num, + gpuMemcpyDeviceToDevice, + dev_ctx.stream()); + phi::backends::gpu::GpuMemcpyAsync(csr_values_data, + coo_values_data, + sizeof(T) * non_zero_num, + gpuMemcpyDeviceToDevice, + dev_ctx.stream()); out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims); } +template +void SparseCooToCsrKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCsrTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "SparseCooToCsrGPUKernel", ([&] { + SparseCooToCsrGPUKernel(dev_ctx, x, out); + })); +} + template __global__ void KernelSparseCooToDense(const IndicesT* indices, - const IndicesT* sparse_offsets, + const int64_t* sparse_offsets, const ValueT* data, ValueT* dense_data, const IndicesT non_zero_num, @@ -466,10 +430,10 @@ __global__ void KernelSparseCooToDense(const IndicesT* indices, } } -template -void SparseCooToDenseKernel(const Context& dev_ctx, - const SparseCooTensor& x, - DenseTensor* out) { +template +void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { const auto non_zero_num = x.nnz(); const auto dense_dims = x.dims(); const auto indices = x.non_zero_indices(); @@ -498,38 +462,24 @@ void SparseCooToDenseKernel(const Context& dev_ctx, offset *= dense_dims[i]; } - auto sparse_offset_meta = phi::DenseTensorMeta( - DataType::INT64, {sparse_dim}, phi::DataLayout::NCHW); - DenseTensor d_sparse_offsets = Empty(dev_ctx, std::move(sparse_offset_meta)); + DenseTensor d_sparse_offsets = Empty(dev_ctx, {sparse_dim}); + + phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), + 
sparse_offsets.data(), + sparse_dim * sizeof(int64_t), + gpuMemcpyHostToDevice, + dev_ctx.stream()); + phi::backends::gpu::GpuMemsetAsync( + out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream()); -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemcpyAsync(d_sparse_offsets.mutable_data(place), - sparse_offsets.data(), - sparse_dim * sizeof(int64_t), - hipMemcpyHostToDevice, - dev_ctx.stream())); - - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemsetAsync(out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpyAsync(d_sparse_offsets.mutable_data(place), - sparse_offsets.data(), - sparse_dim * sizeof(int64_t), - cudaMemcpyHostToDevice, - dev_ctx.stream())); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemsetAsync(out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream())); -#endif auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - KernelSparseCooToDense + KernelSparseCooToDense <<>>(indices.data(), + dev_ctx.stream()>>>(indices.data(), d_sparse_offsets.data(), x_data, out_data, @@ -538,6 +488,16 @@ void SparseCooToDenseKernel(const Context& dev_ctx, sparse_dim); } +template +void SparseCooToDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + DenseTensor* out) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "SparseCooToDenseGPUKernel", ([&] { + SparseCooToDenseGPUKernel(dev_ctx, x, out); + })); +} + } // namespace sparse } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py index 5705763e0af..a72757d5005 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py @@ -168,31 +168,33 @@ class TestSparseConvert(unittest.TestCase): with _test_eager_guard(): indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]] values = [1.0, 2.0, 3.0, 4.0, 5.0] - sparse_x = paddle.incubate.sparse.sparse_coo_tensor( - paddle.to_tensor(indices), - paddle.to_tensor(values), - shape=[3, 4], - stop_gradient=False) - dense_tensor = sparse_x.to_dense() - #test to_dense_grad backward - out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0]] - dense_tensor.backward(paddle.to_tensor(out_grad)) - #mask the out_grad by sparse_x.indices() - correct_x_grad = [2.0, 4.0, 7.0, 9.0, 10.0] - assert np.array_equal(correct_x_grad, - sparse_x.grad.values().numpy()) - - paddle.device.set_device("cpu") - sparse_x_cpu = paddle.incubate.sparse.sparse_coo_tensor( - paddle.to_tensor(indices), - paddle.to_tensor(values), - shape=[3, 4], - stop_gradient=False) - dense_tensor_cpu = sparse_x_cpu.to_dense() - dense_tensor_cpu.backward(paddle.to_tensor(out_grad)) - assert np.array_equal(correct_x_grad, - sparse_x_cpu.grad.values().numpy()) + indices_dtypes = ['int32', 'int64'] + for indices_dtype in indices_dtypes: + sparse_x = paddle.incubate.sparse.sparse_coo_tensor( + paddle.to_tensor(indices, dtype=indices_dtype), + paddle.to_tensor(values), + shape=[3, 4], + stop_gradient=False) + dense_tensor = sparse_x.to_dense() + #test to_dense_grad backward + out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0]] + dense_tensor.backward(paddle.to_tensor(out_grad)) + #mask the out_grad by sparse_x.indices() + correct_x_grad = [2.0, 4.0, 7.0, 9.0, 10.0] + assert np.array_equal(correct_x_grad, + sparse_x.grad.values().numpy()) + + paddle.device.set_device("cpu") + sparse_x_cpu = paddle.incubate.sparse.sparse_coo_tensor( + 
paddle.to_tensor(indices, dtype=indices_dtype), + paddle.to_tensor(values), + shape=[3, 4], + stop_gradient=False) + dense_tensor_cpu = sparse_x_cpu.to_dense() + dense_tensor_cpu.backward(paddle.to_tensor(out_grad)) + assert np.array_equal(correct_x_grad, + sparse_x_cpu.grad.values().numpy()) fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) def test_to_sparse_csr(self): -- GitLab
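Note on the secondary-dispatch pattern introduced above: the public conversion kernels (e.g. SparseCsrToCooKernel<T, Context>) stay templated on the value type T only, and the integer type of the indices/crows tensor (int32 vs int64 in the tests) is resolved at runtime via PD_VISIT_INTEGRAL_TYPES, which runs the given lambda with data_t bound to the concrete type and forwards to the doubly-templated implementation (e.g. SparseCsrToCooCPUKernel<T, data_t>). Below is a minimal, self-contained sketch of that pattern; the DataType enum, the VISIT_INTEGRAL_TYPES macro, and the CsrToCoo* functions are simplified stand-ins, not the actual phi definitions.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

enum class DataType { INT32, INT64 };

// Switches on the runtime dtype and runs the body with `data_t` bound to the
// matching C++ type (mirrors what PD_VISIT_INTEGRAL_TYPES does).
#define VISIT_INTEGRAL_TYPES(DTYPE, NAME, ...)                                 \
  [&] {                                                                        \
    switch (DTYPE) {                                                           \
      case DataType::INT32: {                                                  \
        using data_t = int32_t;                                                \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case DataType::INT64: {                                                  \
        using data_t = int64_t;                                                \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      default:                                                                 \
        throw std::runtime_error(std::string(NAME) + ": unsupported dtype");   \
    }                                                                          \
  }()

// Implementation kernel: templated on both the value type T and the index
// type IntT, like SparseCsrToCooCPUKernel<T, IntT> in the patch.
template <typename T, typename IntT>
void CsrToCooImpl(const std::vector<IntT>& crows, const std::vector<T>& values) {
  std::cout << "crows has " << crows.size() << " entries; dispatched with a "
            << sizeof(IntT) * 8 << "-bit index type; nnz = " << values.size()
            << "\n";
}

// Public kernel: templated on T only; the index dtype is a runtime property of
// the crows tensor, so it is dispatched a second time here.
template <typename T>
void CsrToCooKernel(DataType crows_dtype,
                    const std::vector<int64_t>& crows,
                    const std::vector<T>& values) {
  VISIT_INTEGRAL_TYPES(crows_dtype, "CsrToCooKernel", [&] {
    // The real kernel reinterprets the crows buffer as data_t*; copying keeps
    // this sketch simple and type-safe.
    std::vector<data_t> typed_crows(crows.begin(), crows.end());
    CsrToCooImpl<T, data_t>(typed_crows, values);
  });
}

int main() {
  std::vector<int64_t> crows = {0, 2, 3};
  std::vector<float> values = {1.0f, 2.0f, 3.0f};
  CsrToCooKernel<float>(DataType::INT64, crows, values);  // picks int64_t
  CsrToCooKernel<float>(DataType::INT32, crows, values);  // picks int32_t
  return 0;
}

Taking the lambda through __VA_ARGS__ is what allows its body to contain unparenthesized commas such as CsrToCooImpl<T, data_t>(...), which is why visitor macros of this family are variadic.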
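Note on the memory-copy changes: the per-backend #ifdef PADDLE_WITH_HIP blocks around cudaMemcpyAsync/hipMemcpyAsync and cudaMemsetAsync/hipMemsetAsync are replaced with calls to the unified phi::backends::gpu::GpuMemcpyAsync and GpuMemsetAsync helpers, combined with the gpuMemcpyHostToDevice/DeviceToHost/DeviceToDevice constants added to gpu_types.h. The CUDA-only sketch below illustrates the resulting call-site shape; CheckGpu and the wrapper signatures are illustrative assumptions, not the real phi declarations (which also dispatch to the hip* APIs when built with PADDLE_WITH_HIP).

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Illustrative error check, standing in for PADDLE_ENFORCE_GPU_SUCCESS.
static void CheckGpu(cudaError_t err, const char* what) {
  if (err != cudaSuccess) {
    std::fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(err));
    std::exit(EXIT_FAILURE);
  }
}

// Unified async copy: callers pass a single memcpy kind and a stream, with no
// backend-specific #ifdef at the call site.
inline void GpuMemcpyAsync(void* dst, const void* src, size_t count,
                           cudaMemcpyKind kind, cudaStream_t stream) {
  CheckGpu(cudaMemcpyAsync(dst, src, count, kind, stream), "cudaMemcpyAsync");
}

inline void GpuMemsetAsync(void* dst, int value, size_t count,
                           cudaStream_t stream) {
  CheckGpu(cudaMemsetAsync(dst, value, count, stream), "cudaMemsetAsync");
}

int main() {
  cudaStream_t stream;
  CheckGpu(cudaStreamCreate(&stream), "cudaStreamCreate");

  int host_nnz = 42;
  int* dev_nnz = nullptr;
  CheckGpu(cudaMalloc(&dev_nnz, sizeof(int)), "cudaMalloc");

  // Mirrors the pattern in DenseToSparseCooKernel: zero a device counter,
  // copy a host value to the device, then copy the result back to the host.
  GpuMemsetAsync(dev_nnz, 0, sizeof(int), stream);
  GpuMemcpyAsync(dev_nnz, &host_nnz, sizeof(int), cudaMemcpyHostToDevice,
                 stream);

  int result = -1;
  GpuMemcpyAsync(&result, dev_nnz, sizeof(int), cudaMemcpyDeviceToHost, stream);
  CheckGpu(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
  std::printf("round-tripped value: %d\n", result);

  CheckGpu(cudaFree(dev_nnz), "cudaFree");
  CheckGpu(cudaStreamDestroy(stream), "cudaStreamDestroy");
  return 0;
}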