Unverified · commit fd6b1a02 · authored by zhangkaihuo · committed by GitHub

Add sparse.coalesce (#44256)

* add sparse api coalesce

Parent commit: 77c010a0
@@ -266,6 +266,13 @@
     layout : x
   backward : values_grad
 
+- api: coalesce
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func: coalesce{sparse_coo -> sparse_coo}
+    layout : x
+
 - api: full_like
   args : (Tensor x, Scalar value, DataType dtype=DataType::UNDEFINED)
   output : Tensor(out)
......
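The entry above declares a forward-only sparse API: `coalesce` takes one tensor, returns one tensor, and dispatches to a kernel registered for `sparse_coo -> sparse_coo` inputs and outputs. In the generated dygraph bindings this surfaces as `final_state_sparse_coalesce`, which the Python wrapper added at the end of this diff calls directly. A minimal sketch of that call path, shown only to connect the yaml name to the binding (the binding name and `_C_ops` import come from this diff; `sp_x` is a placeholder):

    # Sketch only: sp_x is an existing SparseCooTensor built in dygraph/eager mode.
    from paddle import _C_ops
    out = _C_ops.final_state_sparse_coalesce(sp_x)  # what paddle.incubate.sparse.coalesce(sp_x) wraps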
@@ -22,9 +22,16 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void CoalescedKernel(const Context& dev_ctx,
-                     const SparseCooTensor& x,
-                     SparseCooTensor* out);
+void CoalesceKernel(const Context& dev_ctx,
+                    const SparseCooTensor& x,
+                    SparseCooTensor* out);
+
+template <typename T, typename Context>
+SparseCooTensor Coalesce(const Context& dev_ctx, const SparseCooTensor& x) {
+  SparseCooTensor coo;
+  CoalesceKernel<T, Context>(dev_ctx, x, &coo);
+  return coo;
+}
 
 } // namespace sparse
 } // namespace phi
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
+#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
 
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"
@@ -22,9 +22,9 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename IntT>
-void CoalescedCPUKernel(const CPUContext& dev_ctx,
-                        const SparseCooTensor& x,
-                        SparseCooTensor* out) {
+void CoalesceCPUKernel(const CPUContext& dev_ctx,
+                       const SparseCooTensor& x,
+                       SparseCooTensor* out) {
   const DenseTensor& x_indices = x.non_zero_indices();
   const DenseTensor& x_values = x.non_zero_elements();
   DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
@@ -95,22 +95,22 @@ void CoalescedCPUKernel(const CPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void CoalescedKernel(const Context& dev_ctx,
-                     const SparseCooTensor& x,
-                     SparseCooTensor* out) {
+void CoalesceKernel(const Context& dev_ctx,
+                    const SparseCooTensor& x,
+                    SparseCooTensor* out) {
   PD_VISIT_INTEGRAL_TYPES(
-      x.non_zero_indices().dtype(), "CoalescedCPUKernel", ([&] {
-        CoalescedCPUKernel<T, data_t>(dev_ctx, x, out);
+      x.non_zero_indices().dtype(), "CoalesceCPUKernel", ([&] {
+        CoalesceCPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 } // namespace sparse
 } // namespace phi
 
-PD_REGISTER_KERNEL(sort,
+PD_REGISTER_KERNEL(coalesce,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::CoalescedKernel,
+                   phi::sparse::CoalesceKernel,
                    float,
                    double,
                    phi::dtype::float16,
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
+#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
@@ -27,9 +27,9 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename IntT>
-void CoalescedGPUKernel(const GPUContext& dev_ctx,
-                        const SparseCooTensor& x,
-                        SparseCooTensor* out) {
+void CoalesceGPUKernel(const GPUContext& dev_ctx,
+                       const SparseCooTensor& x,
+                       SparseCooTensor* out) {
   const DenseTensor& x_indices = x.non_zero_indices();
   const DenseTensor& x_values = x.non_zero_elements();
   DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
@@ -55,11 +55,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
   phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
                                      sparse_offsets.data(),
                                      sizeof(IntT) * sparse_dim,
-#ifdef PADDLE_WITH_HIP
-                                     hipMemcpyHostToDevice,
-#else
-                                     cudaMemcpyHostToDevice,
-#endif
+                                     gpuMemcpyHostToDevice,
                                      dev_ctx.stream());
 
   // 1. flatten indices
@@ -117,11 +113,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
   phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
                                      out_indices.data<IntT>(),
                                      sizeof(IntT),
-#ifdef PADDLE_WITH_HIP
-                                     hipMemcpyDeviceToHost,
-#else
-                                     cudaMemcpyDeviceToHost,
-#endif
+                                     gpuMemcpyDeviceToHost,
                                      dev_ctx.stream());
   dev_ctx.Wait();
@@ -161,22 +153,21 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void CoalescedKernel(const Context& dev_ctx,
-                     const SparseCooTensor& x,
-                     SparseCooTensor* out) {
+void CoalesceKernel(const Context& dev_ctx,
+                    const SparseCooTensor& x,
+                    SparseCooTensor* out) {
   PD_VISIT_INTEGRAL_TYPES(
-      x.non_zero_indices().dtype(), "CoalescedGPUKernel", ([&] {
-        CoalescedGPUKernel<T, data_t>(dev_ctx, x, out);
+      x.non_zero_indices().dtype(), "CoalesceGPUKernel", ([&] {
+        CoalesceGPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 } // namespace sparse
 } // namespace phi
 
-PD_REGISTER_KERNEL(sort,
+PD_REGISTER_KERNEL(coalesce,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::CoalescedKernel,
+                   phi::sparse::CoalesceKernel,
                    float,
                    double,
                    phi::dtype::float16,
......
@@ -19,7 +19,6 @@ limitations under the License. */
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
 
 namespace phi {
 namespace sparse {
@@ -154,9 +153,8 @@ void SparseCooTensorKernel(const Context& dev_ctx,
                            const DenseTensor& indices,
                            const IntArray& dense_shape,
                            SparseCooTensor* out) {
-  SparseCooTensor before_coalesced(
-      indices, values, phi::make_ddim(dense_shape.GetData()));
-  CoalescedKernel<T, Context>(dev_ctx, before_coalesced, out);
+  *out =
+      SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
 }
 
 } // namespace sparse
......
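Note the behavior change in the hunk above: `SparseCooTensorKernel` now builds the COO tensor as-is instead of coalescing it during construction, so duplicate indices survive and callers merge them explicitly. A short sketch of the new contract, mirroring the Python test updates later in this diff (dygraph/eager mode assumed, as in the docstring example at the end):

    import paddle
    from paddle.incubate import sparse

    indices = [[0, 0, 1], [1, 1, 2]]  # the index (0, 1) appears twice
    values = [1.0, 2.0, 3.0]
    x = sparse.sparse_coo_tensor(indices, values)  # no longer coalesced automatically
    x = sparse.coalesce(x)  # explicit call: sorts indices, sums values at duplicates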
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
 #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
 #include "paddle/phi/kernels/sparse/convolution_kernel.h"
@@ -207,6 +208,8 @@ void TestConv3dBase(const std::vector<IntT>& indices,
                     subm,
                     &d_rulebook);
 
+    SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
+
     ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
     ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz());
     for (int i = 0; i < correct_out_dims.size(); i++) {
@@ -217,7 +220,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
         dev_ctx_cpu,
         DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW));
     phi::Copy(dev_ctx_gpu,
-              d_out.non_zero_indices(),
+              tmp_d_out.non_zero_indices(),
               phi::CPUPlace(),
               true,
               &h_indices_tensor);
@@ -231,7 +234,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
         phi::EmptyLike<T>(dev_ctx_cpu, d_out.non_zero_elements());
     phi::Copy(dev_ctx_gpu,
-              d_out.non_zero_elements(),
+              tmp_d_out.non_zero_elements(),
               phi::CPUPlace(),
               true,
              &h_features_tensor);
......
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
 #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
 #include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
@@ -157,6 +158,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
                      dilations,
                      strides,
                      &d_rulebook);
+    SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
 
     ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
     ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz());
@@ -168,7 +170,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
         dev_ctx_cpu,
         DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW));
     phi::Copy(dev_ctx_gpu,
-              d_out.non_zero_indices(),
+              tmp_d_out.non_zero_indices(),
              phi::CPUPlace(),
              true,
              &h_indices_tensor);
@@ -182,7 +184,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
         phi::EmptyLike<T>(dev_ctx_cpu, d_out.non_zero_elements());
     phi::Copy(dev_ctx_gpu,
-              d_out.non_zero_elements(),
+              tmp_d_out.non_zero_elements(),
              phi::CPUPlace(),
              true,
              &h_features_tensor);
......
@@ -53,6 +53,7 @@ class TestSparseConv(unittest.TestCase):
                                               groups=1,
                                               data_format="NDHWC")
            out.backward(out)
+           out = paddle.incubate.sparse.coalesce(out)
            assert np.array_equal(correct_out_values, out.values().numpy())
 
    def test_subm_conv3d(self):
......
@@ -298,6 +298,7 @@ class TestSparseConvert(unittest.TestCase):
             values = paddle.to_tensor(values, dtype='float32')
             sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
                 indices, values)
+            sparse_x = paddle.incubate.sparse.coalesce(sparse_x)
             indices_sorted = [[0, 1], [1, 0]]
             values_sorted = [5.0, 1.0]
             assert np.array_equal(indices_sorted,
@@ -310,6 +311,7 @@ class TestSparseConvert(unittest.TestCase):
             values = paddle.to_tensor(values, dtype='float32')
             sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
                 indices, values)
+            sparse_x = paddle.incubate.sparse.coalesce(sparse_x)
             values_sorted = [[5.0, 5.0], [1.0, 1.0]]
             assert np.array_equal(indices_sorted,
                                   sparse_x.indices().numpy())
......
@@ -30,6 +30,7 @@ from .unary import abs
 from .unary import pow
 from .unary import cast
 from .unary import neg
+from .unary import coalesce
 
 from .binary import mv
 from .binary import matmul
@@ -66,4 +67,5 @@ __all__ = [
     'subtract',
     'multiply',
     'divide',
+    'coalesce',
 ]
@@ -472,3 +472,34 @@ def abs(x, name=None):
 
     """
     return _C_ops.final_state_sparse_abs(x)
+
+
+@dygraph_only
+def coalesce(x):
+    r"""
+    Coalesce sorts and merges the entries of ``x``: after coalescing, the indices
+    of ``x`` are sorted and unique, and the values at duplicate indices are summed.
+
+    Parameters:
+        x (Tensor): the input SparseCooTensor.
+
+    Returns:
+        Tensor: the coalesced SparseCooTensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.incubate import sparse
+            from paddle.fluid.framework import _test_eager_guard
+
+            with _test_eager_guard():
+                indices = [[0, 0, 1], [1, 1, 2]]
+                values = [1.0, 2.0, 3.0]
+                sp_x = sparse.sparse_coo_tensor(indices, values)
+                sp_x = sparse.coalesce(sp_x)
+                print(sp_x.indices())
+                # [[0, 1], [1, 2]]
+                print(sp_x.values())
+                # [3.0, 3.0]
+    """
+    return _C_ops.final_state_sparse_coalesce(x)