From fd6b1a02435d61bcadf812d9f3b16a91f85f0adf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 13 Jul 2022 16:08:25 +0800 Subject: [PATCH] Add sparse.coalesce (#44256) * add sparse api coalesce --- paddle/phi/api/yaml/sparse_api.yaml | 7 ++++ .../{coalesced_kernel.h => coalesce_kernel.h} | 13 +++++-- ...coalesced_kernel.cc => coalesce_kernel.cc} | 22 ++++++------ ...coalesced_kernel.cu => coalesce_kernel.cu} | 35 +++++++------------ .../phi/kernels/sparse/sparse_utils_kernel.h | 6 ++-- .../kernels/test_sparse_conv3d_dev_api.cc | 7 ++-- .../tests/kernels/test_sparse_pool_dev_api.cc | 6 ++-- .../tests/unittests/test_sparse_conv_op.py | 1 + .../tests/unittests/test_sparse_utils_op.py | 2 ++ python/paddle/incubate/sparse/__init__.py | 2 ++ python/paddle/incubate/sparse/unary.py | 31 ++++++++++++++++ 11 files changed, 88 insertions(+), 44 deletions(-) rename paddle/phi/kernels/sparse/{coalesced_kernel.h => coalesce_kernel.h} (71%) rename paddle/phi/kernels/sparse/cpu/{coalesced_kernel.cc => coalesce_kernel.cc} (87%) rename paddle/phi/kernels/sparse/gpu/{coalesced_kernel.cu => coalesce_kernel.cu} (87%) diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml index d8c275ff1f..4c513ed7d2 100644 --- a/paddle/phi/api/yaml/sparse_api.yaml +++ b/paddle/phi/api/yaml/sparse_api.yaml @@ -266,6 +266,13 @@ layout : x backward : values_grad +- api: coalesce + args : (Tensor x) + output : Tensor(out) + kernel : + func: coalesce{sparse_coo -> sparse_coo} + layout : x + - api: full_like args : (Tensor x, Scalar value, DataType dtype=DataType::UNDEFINED) output : Tensor(out) diff --git a/paddle/phi/kernels/sparse/coalesced_kernel.h b/paddle/phi/kernels/sparse/coalesce_kernel.h similarity index 71% rename from paddle/phi/kernels/sparse/coalesced_kernel.h rename to paddle/phi/kernels/sparse/coalesce_kernel.h index 0755579a57..cb8b98fd87 100644 --- a/paddle/phi/kernels/sparse/coalesced_kernel.h +++ b/paddle/phi/kernels/sparse/coalesce_kernel.h @@ -22,9 +22,16 @@ namespace phi { namespace sparse { template -void CoalescedKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out); +void CoalesceKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out); + +template +SparseCooTensor Coalesce(const Context& dev_ctx, const SparseCooTensor& x) { + SparseCooTensor coo; + CoalesceKernel(dev_ctx, x, &coo); + return coo; +} } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc b/paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc similarity index 87% rename from paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc rename to paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc index 9d1f71afce..95d8abd6bc 100644 --- a/paddle/phi/kernels/sparse/cpu/coalesced_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/coalesce_kernel.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/phi/kernels/sparse/coalesced_kernel.h" +#include "paddle/phi/kernels/sparse/coalesce_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/visit_type.h" @@ -22,9 +22,9 @@ namespace phi { namespace sparse { template -void CoalescedCPUKernel(const CPUContext& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { +void CoalesceCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { const DenseTensor& x_indices = x.non_zero_indices(); const DenseTensor& x_values = x.non_zero_elements(); DenseTensor out_indices = phi::EmptyLike(dev_ctx, x_indices); @@ -95,22 +95,22 @@ void CoalescedCPUKernel(const CPUContext& dev_ctx, } template -void CoalescedKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { +void CoalesceKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { PD_VISIT_INTEGRAL_TYPES( - x.non_zero_indices().dtype(), "CoalescedCPUKernel", ([&] { - CoalescedCPUKernel(dev_ctx, x, out); + x.non_zero_indices().dtype(), "CoalesceCPUKernel", ([&] { + CoalesceCPUKernel(dev_ctx, x, out); })); } } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(sort, +PD_REGISTER_KERNEL(coalesce, CPU, ALL_LAYOUT, - phi::sparse::CoalescedKernel, + phi::sparse::CoalesceKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu similarity index 87% rename from paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu rename to paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu index 405384009d..f6aedb8b68 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/sparse/coalesced_kernel.h" +#include "paddle/phi/kernels/sparse/coalesce_kernel.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" @@ -27,9 +27,9 @@ namespace phi { namespace sparse { template -void CoalescedGPUKernel(const GPUContext& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { +void CoalesceGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { const DenseTensor& x_indices = x.non_zero_indices(); const DenseTensor& x_values = x.non_zero_elements(); DenseTensor out_indices = phi::EmptyLike(dev_ctx, x_indices); @@ -55,11 +55,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx, phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), sparse_offsets.data(), sizeof(IntT) * sparse_dim, -#ifdef PADDLE_WITH_HIP - hipMemcpyHostToDevice, -#else - cudaMemcpyHostToDevice, -#endif + gpuMemcpyHostToDevice, dev_ctx.stream()); // 1. 
flatten indices @@ -117,11 +113,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx, phi::backends::gpu::GpuMemcpyAsync(&out_nnz, out_indices.data(), sizeof(IntT), -#ifdef PADDLE_WITH_HIP - hipMemcpyDeviceToHost, -#else - cudaMemcpyDeviceToHost, -#endif + gpuMemcpyDeviceToHost, dev_ctx.stream()); dev_ctx.Wait(); @@ -161,22 +153,21 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx, } template -void CoalescedKernel(const Context& dev_ctx, - const SparseCooTensor& x, - SparseCooTensor* out) { +void CoalesceKernel(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* out) { PD_VISIT_INTEGRAL_TYPES( - x.non_zero_indices().dtype(), "CoalescedGPUKernel", ([&] { - CoalescedGPUKernel(dev_ctx, x, out); + x.non_zero_indices().dtype(), "CoalesceGPUKernel", ([&] { + CoalesceGPUKernel(dev_ctx, x, out); })); } - } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(sort, +PD_REGISTER_KERNEL(coalesce, GPU, ALL_LAYOUT, - phi::sparse::CoalescedKernel, + phi::sparse::CoalesceKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index 93abf70b24..12d55596a9 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/sparse/coalesced_kernel.h" namespace phi { namespace sparse { @@ -154,9 +153,8 @@ void SparseCooTensorKernel(const Context& dev_ctx, const DenseTensor& indices, const IntArray& dense_shape, SparseCooTensor* out) { - SparseCooTensor before_coalesced( - indices, values, phi::make_ddim(dense_shape.GetData())); - CoalescedKernel(dev_ctx, before_coalesced, out); + *out = + SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData())); } } // namespace sparse diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index f08c7b0872..2efdd47998 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -22,6 +22,7 @@ limitations under the License. 
*/ #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/sparse/coalesce_kernel.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" @@ -207,6 +208,8 @@ void TestConv3dBase(const std::vector& indices, subm, &d_rulebook); + SparseCooTensor tmp_d_out = sparse::Coalesce(dev_ctx_gpu, d_out); + ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz()); for (int i = 0; i < correct_out_dims.size(); i++) { @@ -217,7 +220,7 @@ void TestConv3dBase(const std::vector& indices, dev_ctx_cpu, DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW)); phi::Copy(dev_ctx_gpu, - d_out.non_zero_indices(), + tmp_d_out.non_zero_indices(), phi::CPUPlace(), true, &h_indices_tensor); @@ -231,7 +234,7 @@ void TestConv3dBase(const std::vector& indices, phi::EmptyLike(dev_ctx_cpu, d_out.non_zero_elements()); phi::Copy(dev_ctx_gpu, - d_out.non_zero_elements(), + tmp_d_out.non_zero_elements(), phi::CPUPlace(), true, &h_features_tensor); diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc index 460dca59c7..eeba9cdc13 100644 --- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/sparse/coalesce_kernel.h" #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" #include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" @@ -157,6 +158,7 @@ void TestMaxPoolBase(const std::vector& indices, dilations, strides, &d_rulebook); + SparseCooTensor tmp_d_out = sparse::Coalesce(dev_ctx_gpu, d_out); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz()); @@ -168,7 +170,7 @@ void TestMaxPoolBase(const std::vector& indices, dev_ctx_cpu, DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW)); phi::Copy(dev_ctx_gpu, - d_out.non_zero_indices(), + tmp_d_out.non_zero_indices(), phi::CPUPlace(), true, &h_indices_tensor); @@ -182,7 +184,7 @@ void TestMaxPoolBase(const std::vector& indices, phi::EmptyLike(dev_ctx_cpu, d_out.non_zero_elements()); phi::Copy(dev_ctx_gpu, - d_out.non_zero_elements(), + tmp_d_out.non_zero_elements(), phi::CPUPlace(), true, &h_features_tensor); diff --git a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py index e1a9b2428b..9501b2c895 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py @@ -53,6 +53,7 @@ class TestSparseConv(unittest.TestCase): groups=1, data_format="NDHWC") out.backward(out) + out = paddle.incubate.sparse.coalesce(out) assert np.array_equal(correct_out_values, out.values().numpy()) def test_subm_conv3d(self): diff --git a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py index ac69469cbb..53c84c9d1f 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py @@ -298,6 +298,7 @@ class TestSparseConvert(unittest.TestCase): 
             values = paddle.to_tensor(values, dtype='float32')
             sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
                 indices, values)
+            sparse_x = paddle.incubate.sparse.coalesce(sparse_x)
             indices_sorted = [[0, 1], [1, 0]]
             values_sorted = [5.0, 1.0]
             assert np.array_equal(indices_sorted,
                                   sparse_x.indices().numpy())
@@ -310,6 +311,7 @@ class TestSparseConvert(unittest.TestCase):
             values = paddle.to_tensor(values, dtype='float32')
             sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
                 indices, values)
+            sparse_x = paddle.incubate.sparse.coalesce(sparse_x)
             values_sorted = [[5.0, 5.0], [1.0, 1.0]]
             assert np.array_equal(indices_sorted,
                                   sparse_x.indices().numpy())
diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py
index c56ada3468..47c7a312e2 100644
--- a/python/paddle/incubate/sparse/__init__.py
+++ b/python/paddle/incubate/sparse/__init__.py
@@ -30,6 +30,7 @@ from .unary import abs
 from .unary import pow
 from .unary import cast
 from .unary import neg
+from .unary import coalesce
 
 from .binary import mv
 from .binary import matmul
@@ -66,4 +67,5 @@ __all__ = [
     'subtract',
     'multiply',
     'divide',
+    'coalesce',
 ]
diff --git a/python/paddle/incubate/sparse/unary.py b/python/paddle/incubate/sparse/unary.py
index d3fb55b737..1725c8791f 100644
--- a/python/paddle/incubate/sparse/unary.py
+++ b/python/paddle/incubate/sparse/unary.py
@@ -472,3 +472,34 @@ def abs(x, name=None):
     """
     return _C_ops.final_state_sparse_abs(x)
+
+
+@dygraph_only
+def coalesce(x):
+    r"""
+    The coalesce operator sorts the indices of x and merges duplicate entries; after coalescing, the indices of x are sorted and unique.
+
+    Parameters:
+        x (Tensor): the input SparseCooTensor.
+
+    Returns:
+        Tensor: the coalesced SparseCooTensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.incubate import sparse
+            from paddle.fluid.framework import _test_eager_guard
+
+            with _test_eager_guard():
+                indices = [[0, 0, 1], [1, 1, 2]]
+                values = [1.0, 2.0, 3.0]
+                sp_x = sparse.sparse_coo_tensor(indices, values)
+                sp_x = sparse.coalesce(sp_x)
+                print(sp_x.indices())
+                #[[0, 1], [1, 2]]
+                print(sp_x.values())
+                #[3.0, 3.0]
+    """
+    return _C_ops.final_state_sparse_coalesce(x)
-- 
GitLab
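
Note (not part of the patch): a minimal usage sketch of the behavioral change this commit introduces. After this change, SparseCooTensorKernel no longer coalesces during construction, so sparse_coo_tensor() keeps duplicate indices and coalesce() must be called explicitly, as the updated unit tests above do. The sketch assumes eager mode via _test_eager_guard, following the patch's own tests; the data and expected results are taken from the docstring example.

# Minimal sketch: explicit coalesce after construction.
import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    indices = [[0, 0, 1], [1, 1, 2]]   # the index (0, 1) appears twice
    values = [1.0, 2.0, 3.0]
    sp_x = paddle.incubate.sparse.sparse_coo_tensor(indices, values)
    # Construction keeps the duplicate entry; coalesce() sorts the indices
    # and sums the values of duplicates.
    sp_x = paddle.incubate.sparse.coalesce(sp_x)
    assert np.array_equal(sp_x.indices().numpy(), [[0, 1], [1, 2]])
    assert np.array_equal(sp_x.values().numpy(), [3.0, 3.0])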