From acf3e526c1a69d777e7b48005b317f41f307f724 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Thu, 25 May 2023 15:18:14 +0800
Subject: [PATCH] [Sparse]fix sparse bug (#53390)

---
 paddle/phi/core/sparse_coo_tensor.h           |  9 ++-
 paddle/phi/core/sparse_csr_tensor.h           |  9 ++-
 .../kernels/sparse/cpu/sparse_utils_kernel.cc | 30 ++++++---
 paddle/phi/kernels/sparse/gpu/mask_kernel.cu  | 10 ++-
 .../kernels/sparse/gpu/sparse_utils_kernel.cu | 61 +++++++++++++------
 .../tests/unittests/test_sparse_utils_op.py   | 21 +++++++
 6 files changed, 104 insertions(+), 36 deletions(-)

diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h
index 0e9273f321f..f0343585485 100644
--- a/paddle/phi/core/sparse_coo_tensor.h
+++ b/paddle/phi/core/sparse_coo_tensor.h
@@ -126,8 +126,13 @@ class SparseCooTensor : public TensorBase,
   bool valid() const noexcept override { return non_zero_elements_.valid(); }
 
   /// \brief Test whether the non_zero_elements_ storage is allocated.
-  /// return Whether the non_zero_elements_ storage is allocated.
-  bool initialized() const override { return non_zero_elements_.initialized(); }
+  /// In special cases, when nnz=0, non_zero_elements_ will not need to be
+  /// initialized, but it is necessary to return true here, otherwise the
+  /// gradient will be None. return Whether the non_zero_elements_ storage is
+  /// allocated.
+  bool initialized() const override {
+    return values().initialized() || (nnz() == 0 && numel() > 0);
+  }
 
   /// \brief resize sparse coo tensor.
   /// \param dense_dims The dims of original dense tensor.
diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h
index 8692c8d7a20..38f330a7275 100644
--- a/paddle/phi/core/sparse_csr_tensor.h
+++ b/paddle/phi/core/sparse_csr_tensor.h
@@ -131,8 +131,13 @@ class SparseCsrTensor : public TensorBase,
   bool valid() const noexcept override { return non_zero_elements_.valid(); }
 
   /// \brief Test whether the non_zero_elements_ storage is allocated.
-  /// return Whether the non_zero_elements_ storage is allocated.
-  bool initialized() const override { return non_zero_elements_.initialized(); }
+  /// In special cases, when nnz=0, non_zero_elements_ will not need to be
+  /// initialized, but it is necessary to return true here, otherwise the
+  /// gradient will be None. return Whether the non_zero_elements_ storage is
+  /// allocated.
+  bool initialized() const override {
+    return values().initialized() || (nnz() == 0 && numel() > 0);
+  }
 
   /// \brief resize sparse csr tensor.
   /// \param dense_dims The dims of original dense tensor.
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
index fab853b3c56..d50d4568cc3 100644
--- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
@@ -113,13 +113,6 @@ void CsrToCooCPUKernel(const CPUContext& dev_ctx,
                        SparseCooTensor* out) {
   const DDim& x_dims = x.dims();
   const int64_t non_zero_num = x.cols().numel();
-  const auto& csr_crows = x.crows();
-  const auto& csr_cols = x.cols();
-  const auto& csr_values = x.values();
-  const IntT* csr_crows_data = csr_crows.data();
-  const IntT* csr_cols_data = csr_cols.data();
-  const T* csr_values_data = csr_values.data();
-
   int64_t sparse_dim = 2;
   if (x_dims.size() == 3) {
     sparse_dim = 3;
@@ -127,6 +120,17 @@
   phi::DenseTensor indices = phi::Empty(dev_ctx, {sparse_dim, non_zero_num});
   phi::DenseTensor values = phi::Empty(dev_ctx, {non_zero_num});
+  if (x.nnz() <= 0) {
+    out->SetMember(indices, values, x_dims, true);
+    return;
+  }
+  const auto& csr_crows = x.crows();
+  const auto& csr_cols = x.cols();
+  const auto& csr_values = x.values();
+  const IntT* csr_crows_data = csr_crows.data();
+  const IntT* csr_cols_data = csr_cols.data();
+  const T* csr_values_data = csr_values.data();
+
   IntT* coo_indices = indices.data();
   IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
   IntT* coo_rows_data =
@@ -177,7 +181,6 @@ void CooToCsrCPUKernel(const CPUContext& dev_ctx,
                        phi::errors::InvalidArgument(
                            "SparseCsrTensor only support 2-D or 3-D matrix"));
   const int64_t non_zero_num = x.nnz();
-  if (non_zero_num <= 0) return;
 
   int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
   int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
@@ -185,6 +188,10 @@
   phi::DenseTensor crows = phi::Empty(dev_ctx, {batchs * (rows + 1)});
   phi::DenseTensor cols = phi::Empty(dev_ctx, {non_zero_num});
   phi::DenseTensor values = phi::EmptyLike(dev_ctx, x.values());
+  if (non_zero_num <= 0) {
+    out->SetMember(crows, cols, values, x_dims);
+    return;
+  }
   IntT* csr_crows_data = crows.data();
   IntT* csr_cols_data = cols.data();
   T* csr_values_data = values.data();
@@ -268,6 +275,12 @@ void CooToDenseCPUKernel(const CPUContext& dev_ctx,
   const T* x_data = values.data();
   dev_ctx.template Alloc(out);
   T* out_data = out->data();
+  memset(out_data, 0, sizeof(T) * out->numel());
+
+  if (x.nnz() <= 0) {
+    return;
+  }
+
   int64_t base_offset = 1;
   for (int64_t i = 0; i < dense_dim; i++) {
     base_offset *= dense_dims[sparse_dim + i];
@@ -279,7 +292,6 @@
     offset *= dense_dims[i];
   }
 
-  memset(out_data, 0, sizeof(T) * out->numel());
   for (auto i = 0; i < non_zero_num; i++) {
     int64_t index = 0;
     for (int j = 0; j < sparse_dim; j++) {
diff --git a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
index bae969cf23e..3b93ff9638c 100644
--- a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
@@ -61,6 +61,13 @@ void MaskCooGPUKernel(const GPUContext& dev_ctx,
       phi::errors::InvalidArgument("the input x and mask must have the shape"));
   const DenseTensor& indices = mask.indices();
   const DenseTensor& values = mask.values();
+  DenseTensor out_indices = phi::EmptyLike(dev_ctx, indices);
+  DenseTensor out_values = phi::EmptyLike(dev_ctx, values);
+  if (mask.nnz() <= 0) {
+    out->SetMember(out_indices, out_values, dims, true);
+    return;
+  }
+
   const int sparse_dim = mask.sparse_dim();
   DenseTensor sparse_offsets = phi::Empty(
       dev_ctx,
@@ -75,9 +82,6 @@ void MaskCooGPUKernel(const GPUContext& dev_ctx,
                                    gpuMemcpyHostToDevice,
                                    dev_ctx.stream());
 
-  DenseTensor out_indices = phi::EmptyLike(dev_ctx, indices);
-  DenseTensor out_values = phi::EmptyLike(dev_ctx, values);
-
   phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);
 
   const IntT* indices_ptr = indices.data();
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
index 94fe0570563..084cb0e60bb 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -164,18 +164,20 @@ void DenseToCooKernel(const Context& dev_ctx,
   T* sparse_data = dev_ctx.template Alloc(&values);
 
   // 3. calc indices by indexs and get values by indexs
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
-  GetNonZeroElementsAndIndices<<>>(x_data,
-                                   sparse_dim,
-                                   cols,
-                                   d_x_dims.data(),
-                                   non_zero_num,
-                                   temp_indexs_ptr,
-                                   indices_data,
-                                   sparse_data);
+  if (non_zero_num > 0) {
+    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
+    GetNonZeroElementsAndIndices<<>>(x_data,
+                                     sparse_dim,
+                                     cols,
+                                     d_x_dims.data(),
+                                     non_zero_num,
+                                     temp_indexs_ptr,
+                                     indices_data,
+                                     sparse_data);
+  }
 
   out->SetMember(indices, values, x_dims, true);
 }
@@ -218,6 +220,21 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
                        SparseCooTensor* out) {
   const DDim& x_dims = x.dims();
   const int64_t non_zero_num = x.cols().numel();
+  int64_t sparse_dim = 2;
+  if (x_dims.size() == 3) {
+    sparse_dim = 3;
+  }
+
+  if (x.nnz() <= 0) {
+#ifdef PADDLE_WITH_HIP
+    DenseTensor indices = phi::Empty(dev_ctx, {sparse_dim, non_zero_num});
+#else
+    DenseTensor indices = phi::Empty(dev_ctx, {sparse_dim, non_zero_num});
+#endif
+    DenseTensor values = phi::EmptyLike(dev_ctx, x.values());
+    out->SetMember(indices, values, x_dims, true);
+    return;
+  }
 
   // rocsparse_csr2coo only support index with type 'rocsparse_int' (aka 'int')
   // now
@@ -235,10 +252,6 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
   const auto& csr_values = x.values();
   const T* csr_values_data = csr_values.data();
 
-  int64_t sparse_dim = 2;
-  if (x_dims.size() == 3) {
-    sparse_dim = 3;
-  }
   int batches = x_dims.size() == 2 ? 1 : x_dims[0];
   int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
@@ -395,7 +408,6 @@ void CooToCsrGPUKernel(const GPUContext& dev_ctx,
                        phi::errors::InvalidArgument(
                            "SparseCsrTensor only support 2-D or 3-D matrix"));
   const int64_t non_zero_num = x.nnz();
-  if (non_zero_num <= 0) return;
 
   int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
   int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
@@ -403,6 +415,10 @@
   phi::DenseTensor crows = phi::Empty(dev_ctx, {batchs * (rows + 1)});
   phi::DenseTensor cols = phi::Empty(dev_ctx, {non_zero_num});
   phi::DenseTensor values = phi::EmptyLike(dev_ctx, x.values());
+  if (non_zero_num <= 0) {
+    out->SetMember(crows, cols, values, x_dims);
+    return;
+  }
   IntT* csr_crows_data = crows.data();
   IntT* csr_cols_data = cols.data();
   T* csr_values_data = values.data();
@@ -503,10 +519,17 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx,
   const int64_t dense_dim = values.dims().size() - 1;
 
   const auto place = dev_ctx.GetPlace();
-  const T* x_data = values.data();
   dev_ctx.template Alloc(out);
   T* out_data = out->data();
 
+  phi::backends::gpu::GpuMemsetAsync(
+      out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream());
+
+  if (x.nnz() <= 0) {
+    return;
+  }
+
+  const T* x_data = values.data();
   int64_t base_offset = 1;
   for (int64_t i = 0; i < dense_dim; i++) {
     base_offset *= dense_dims[sparse_dim + i];
@@ -525,8 +548,6 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx,
                                      sparse_dim * sizeof(int64_t),
                                      gpuMemcpyHostToDevice,
                                      dev_ctx.stream());
-  phi::backends::gpu::GpuMemsetAsync(
-      out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream());
 
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
index f4d56d7e41b..60cf3a7a520 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
@@ -351,6 +351,27 @@ class TestSparseConvert(unittest.TestCase):
             dense_x[2] = 0
             verify(dense_x)
 
+    def test_zero_nnz(self):
+        for device in devices:
+            if device == 'cpu' or (
+                device == 'gpu' and paddle.is_compiled_with_cuda()
+            ):
+                paddle.device.set_device(device)
+                x1 = paddle.zeros([2, 2, 2])
+                x2 = paddle.zeros([2, 2, 2])
+                sp_csr_x = x1.to_sparse_csr()
+                sp_coo_x = x2.to_sparse_coo(1)
+                sp_coo_x.stop_gradient = False
+
+                out1 = sp_csr_x.to_dense()
+                out2 = sp_coo_x.to_dense()
+                out2.backward()
+                np.testing.assert_allclose(out1.numpy(), x1.numpy())
+                np.testing.assert_allclose(out2.numpy(), x2.numpy())
+                np.testing.assert_allclose(
+                    sp_coo_x.grad.to_dense().numpy().sum(), 0.0
+                )
+
 
 class TestCooError(unittest.TestCase):
     def test_small_shape(self):
-- 
GitLab
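
For reference, the zero-nnz behavior this patch covers can be exercised from the Python API roughly as follows. This is an illustrative sketch that mirrors the added test_zero_nnz case; the shapes are arbitrary and the script assumes a default (CPU or CUDA) device.

import numpy as np
import paddle

# An all-zero dense tensor converts to sparse tensors with nnz == 0.
x = paddle.zeros([2, 2, 2])
sp_csr = x.to_sparse_csr()
sp_coo = x.to_sparse_coo(1)
sp_coo.stop_gradient = False

# With this fix, converting back to dense succeeds and round-trips to zeros,
# and the gradient of the zero-nnz COO tensor is all zeros instead of None.
np.testing.assert_allclose(sp_csr.to_dense().numpy(), x.numpy())
out = sp_coo.to_dense()
out.backward()
np.testing.assert_allclose(out.numpy(), x.numpy())
np.testing.assert_allclose(sp_coo.grad.to_dense().numpy().sum(), 0.0)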