From e8e3b9976e04f08f89fb439fc408f83583eb07e7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 5 May 2022 13:36:54 +0800 Subject: [PATCH] fix sparse mask (#42305) --- paddle/phi/core/sparse_coo_tensor.cc | 8 ++++++++ paddle/phi/core/sparse_coo_tensor.h | 6 ++++++ .../phi/kernels/sparse/cpu/sparse_mask_kernel.cc | 4 ++-- .../sparse/cpu/sparse_pool_grad_kernel.cc | 2 +- .../kernels/sparse/cpu/sparse_utils_kernel.cc | 2 +- .../phi/kernels/sparse/gpu/sparse_mask_kernel.cu | 13 +++++-------- .../sparse/gpu/sparse_pool_grad_kernel.cu | 2 +- .../tests/unittests/test_sparse_pooling_op.py | 16 +++++++++++----- 8 files changed, 35 insertions(+), 18 deletions(-) diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc index 7d4261ef82..bf4d601c0b 100644 --- a/paddle/phi/core/sparse_coo_tensor.cc +++ b/paddle/phi/core/sparse_coo_tensor.cc @@ -115,4 +115,12 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices, this->coalesced_ = coalesced; } +int32_t SparseCooTensor::sparse_dim() const { + return non_zero_indices_.dims()[0]; +} + +int32_t SparseCooTensor::dense_dim() const { + return dims_.size() - sparse_dim(); +} + } // namespace phi diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index ec43c5d621..c65b5ce574 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -150,6 +150,12 @@ class SparseCooTensor : public TensorBase, /// \brief set the dims of original dense tensor void set_dims(const DDim& dims) { this->dims_ = dims; } + /// \brief get the sparse dim + int32_t sparse_dim() const; + + /// \brief get the dense dim + int32_t dense_dim() const; + private: // save the indices of non zero elements in original dense tensor DenseTensor non_zero_indices_; diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc index 0ec8b808ba..0e5714b174 100644 --- 
a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc @@ -39,7 +39,7 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx, phi::errors::InvalidArgument("the input x and mask must have the shape")); const DenseTensor& indices = mask.non_zero_indices(); const DenseTensor& values = mask.non_zero_elements(); - int sparse_dim = indices.dims().size(); + const int sparse_dim = mask.sparse_dim(); DenseTensor out_indices = phi::EmptyLike(dev_ctx, indices); DenseTensor out_values = phi::EmptyLike(dev_ctx, values); @@ -95,7 +95,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx, 2, phi::errors::InvalidArgument("the mask_indices must be 2-D tensor")); - const int64_t sparse_dim = x.non_zero_indices().dims()[0]; + const int32_t sparse_dim = x.sparse_dim(); std::vector sparse_offsets(sparse_dim), x_indexs(x.nnz()), mask_indexs(mask_indices.dims()[1]); diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc index 78b6354f44..71a0095395 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc @@ -50,7 +50,7 @@ void MaxPoolGradCPUKernel(const CPUContext& dev_ctx, DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); T* x_grad_ptr = x_grad_values.data(); - memset(x_grad_ptr, 0, sizeof(T) * x_grad->numel()); + memset(x_grad_ptr, 0, sizeof(T) * x_grad_values.numel()); phi::Copy(dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index 685aa6b30b..69ac0417f7 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -254,7 +254,7 @@ void SparseCooToDenseKernel(const Context& dev_ctx, if 
(indices_dims.size() == 1) { sparse_dim = 1; } - const int64_t dense_dim = values.dims().size() - 1; + const int64_t dense_dim = x.dense_dim(); const T* x_data = values.data(); *out = phi::Empty( diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu index 4253845956..81c63c48eb 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu @@ -42,7 +42,7 @@ __global__ void MaskKernel(const T* x_ptr, int64_t col_i = i - out_i * cols; int64_t index = 0; for (int j = 0; j < sparse_dim; j++) { - index += indices_ptr[j * non_zero_num + i] * sparse_offsets[j]; + index += indices_ptr[j * non_zero_num + out_i] * sparse_offsets[j]; } out_values_ptr[out_i * cols + col_i] = x_ptr[index * cols + col_i]; } @@ -60,16 +60,13 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, phi::errors::InvalidArgument("the input x and mask must have the shape")); const DenseTensor& indices = mask.non_zero_indices(); const DenseTensor& values = mask.non_zero_elements(); - int sparse_dim = indices.dims().size(); + const int sparse_dim = mask.sparse_dim(); DenseTensor sparse_offsets = phi::Empty( dev_ctx, DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW)); std::vector h_sparse_offsets(sparse_dim); - int64_t offset = 1; - for (int i = sparse_dim - 1; i >= 0; i--) { - h_sparse_offsets[i] = offset; - offset *= dims[i]; - } + phi::funcs::sparse::CalcOffsetsPerDim( + dims, sparse_dim, h_sparse_offsets.data()); phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data(), &h_sparse_offsets[0], @@ -151,7 +148,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, 2, phi::errors::InvalidArgument("the mask_indices must be 2-D tensor")); - const int64_t sparse_dim = x.non_zero_indices().dims()[0]; + const int32_t sparse_dim = x.sparse_dim(); auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); std::vector sparse_offsets(sparse_dim); diff --git 
a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu index bd862a44af..c22e67eef6 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu @@ -64,7 +64,7 @@ void MaxPoolGradGPUKernel(const GPUContext& dev_ctx, int rulebook_len = rulebook.dims()[1]; const IntT* rulebook_ptr = rulebook.data(); std::vector offsets(kernel_size + 1), counter(kernel_size, 0), - h_counter(kernel_size); + h_counter(rulebook_len, 0); phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], rulebook_ptr, rulebook_len * sizeof(IntT), diff --git a/python/paddle/fluid/tests/unittests/test_sparse_pooling_op.py b/python/paddle/fluid/tests/unittests/test_sparse_pooling_op.py index a1a3849f71..8d65a4c444 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_pooling_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_pooling_op.py @@ -19,6 +19,7 @@ import paddle import paddle.fluid.core as core from paddle import _C_ops from paddle.fluid.framework import _test_eager_guard +import copy class TestMaxPool3DFunc(unittest.TestCase): @@ -44,23 +45,28 @@ class TestMaxPool3DFunc(unittest.TestCase): def test(self): with _test_eager_guard(): self.setUp() + self.dense_x.stop_gradient = False sparse_x = self.dense_x.to_sparse_coo(4) - out = paddle.sparse.functional.max_pool3d( + sparse_out = paddle.sparse.functional.max_pool3d( sparse_x, self.kernel_sizes, stride=self.strides, padding=self.paddings) - out = out.to_dense() + out = sparse_out.to_dense() + out.backward(out) + dense_x = copy.deepcopy(self.dense_x) dense_out = paddle.nn.functional.max_pool3d( - self.dense_x, + dense_x, self.kernel_sizes, stride=self.strides, padding=self.paddings, data_format='NDHWC') + dense_out.backward(dense_out) + #compare with dense - assert np.allclose(dense_out.flatten().numpy(), - out.flatten().numpy()) + assert np.allclose(dense_out.numpy(), out.numpy()) + assert 
np.allclose(dense_x.grad.numpy(), self.dense_x.grad.numpy()) class TestStride(TestMaxPool3DFunc): -- GitLab