From 5fef043dc0950c2884ad538303f8f56ee3b1c86f Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Tue, 18 Oct 2022 19:51:24 +0800 Subject: [PATCH] [cherry-pick 2.4] add sparse api transpose/reshape/is_same_shape (#47076) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增sparse.is_same_shape、sparse.reshape、sparse.transpose 三个API --- paddle/fluid/pybind/eager_method.cc | 13 + paddle/phi/api/yaml/sparse_backward.yaml | 22 ++ paddle/phi/api/yaml/sparse_ops.yaml | 23 ++ paddle/phi/core/sparse_coo_tensor.h | 2 +- paddle/phi/core/sparse_csr_tensor.h | 4 +- .../kernels/sparse/cpu/reshape_grad_kernel.cc | 73 ++++ .../phi/kernels/sparse/cpu/reshape_kernel.cc | 117 ++++++ .../kernels/sparse/cpu/sparse_utils_kernel.cc | 6 +- .../sparse/cpu/transpose_grad_kernel.cc | 78 ++++ .../kernels/sparse/cpu/transpose_kernel.cc | 231 ++++++++++++ .../kernels/sparse/gpu/reshape_grad_kernel.cu | 77 ++++ .../phi/kernels/sparse/gpu/reshape_kernel.cu | 165 +++++++++ .../kernels/sparse/gpu/sparse_utils_kernel.cu | 6 +- .../sparse/gpu/transpose_grad_kernel.cu | 80 +++++ .../kernels/sparse/gpu/transpose_kernel.cu | 338 ++++++++++++++++++ paddle/phi/kernels/sparse/unary_grad_kernel.h | 24 ++ paddle/phi/kernels/sparse/unary_kernel.h | 82 +++++ paddle/phi/tests/kernels/CMakeLists.txt | 4 + .../kernels/test_sparse_transpose_dev_api.cc | 165 +++++++++ .../unittests/test_sparse_is_same_shape.py | 125 +++++++ .../tests/unittests/test_sparse_reshape_op.py | 136 +++++++ .../unittests/test_sparse_transpose_op.py | 77 ++++ python/paddle/incubate/sparse/__init__.py | 37 +- python/paddle/incubate/sparse/binary.py | 33 ++ python/paddle/incubate/sparse/unary.py | 88 +++++ 25 files changed, 1970 insertions(+), 36 deletions(-) create mode 100644 paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc create mode 100644 paddle/phi/kernels/sparse/cpu/reshape_kernel.cc create mode 100644 paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc create mode 100644 paddle/phi/kernels/sparse/cpu/transpose_kernel.cc create mode 100644 paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu create mode 100644 paddle/phi/kernels/sparse/gpu/reshape_kernel.cu create mode 100644 paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu create mode 100644 paddle/phi/kernels/sparse/gpu/transpose_kernel.cu create mode 100644 paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_is_same_shape.py create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index c782b4df585..8649f88d48d 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1588,6 +1588,15 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* tensor_method_is_same_shape(TensorObject* self, + PyObject* args, + PyObject* kwargs) { + EAGER_TRY + auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0); + return ToPyObject(self->tensor.shape() == other.shape()); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args, PyObject* kwargs) { @@ -1983,6 +1992,10 @@ PyMethodDef variable_methods[] = { (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr, METH_VARARGS | METH_KEYWORDS, NULL}, + {"is_same_shape", + 
(PyCFunction)(void (*)(void))tensor_method_is_same_shape, + METH_VARARGS | METH_KEYWORDS, + NULL}, {"to_sparse_csr", (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr, METH_VARARGS | METH_KEYWORDS, diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 8347ee200e8..40b646cb389 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -260,6 +260,17 @@ func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr} +- backward_op : reshape_grad + forward : reshape(Tensor x, IntArray shape) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, + reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr} + - backward_op : scale_grad forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out) args : (Tensor out_grad, float scale) @@ -385,6 +396,17 @@ kernel : func : coo_to_dense { sparse_coo -> dense } +- backward_op : transpose_grad + forward : transpose(Tensor x, int[] perm) -> Tensor(out) + args : (Tensor out_grad, int[] perm) + output : Tensor(x_grad) + infer_meta : + func : TransposeGradInferMeta + param : [out_grad, perm] + kernel : + func : transpose_coo_grad {sparse_coo -> sparse_coo}, + transpose_csr_grad {sparse_csr -> sparse_csr} + - backward_op : values_grad forward : values_coo(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index a917012b2f7..c6ad1bfa583 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -457,3 +457,26 @@ mv_csr{sparse_csr, dense -> dense} layout : x backward: mv_grad + +- op : transpose + args : (Tensor x, int[] perm) + output : Tensor(out) + infer_meta : + func : TransposeInferMeta + param: [ x, perm ] + kernel : + func : transpose_coo{sparse_coo -> sparse_coo}, + transpose_csr{sparse_csr -> sparse_csr} + layout : x + backward : transpose_grad + +- op : reshape + args : (Tensor x, IntArray shape) + output : Tensor(out) + infer_meta : + func : ReshapeInferMeta + kernel : + func : reshape_coo{sparse_coo -> sparse_coo}, + reshape_csr{sparse_csr -> sparse_csr} + layout : x + backward : reshape_grad diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index f8869aa524d..a28229996c8 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -274,7 +274,7 @@ class SparseCooTensor : public TensorBase, [0, 0, 0, 0]] dims_ = (4, 4) non_zero_elements_ = [[0, 1, 0, 0], [0, 0, 4, 0]] - non_zero_indices_ = [0, 2], + non_zero_indices_ = [[0, 2], [1, 2]] */ }; diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index 056d049942a..2acb35915a9 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase, [0, 0, 4, 0], [0, 5, 0, 6]] dims_ = (4, 4) - non_zero_elements_ = [1, 2, 3, 4, 5 ,6] + non_zero_elements_ = [1, 2, 3, 4, 5, 6] non_zero_crows_ = [0, 1, 3, 4, 6] non_zero_cols_ = [1, 0, 3, 2, 1, 3] */ @@ -228,7 +228,7 @@ class SparseCsrTensor : public TensorBase, [0, 0, 4, 0], [0, 5, 0, 0]]] dims_ = (2, 4, 4) - non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5] + non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 
5] non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5] non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1] */ diff --git a/paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc new file mode 100644 index 00000000000..fc843f81c31 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +template +void ReshapeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + SparseCooTensor* dx) { + EmptyLikeCooKernel(dev_ctx, x, dx); + phi::IntArray x_shape(phi::vectorize(x.dims())); + ReshapeCooKernel(dev_ctx, dout, x_shape, dx); +} + +template +void ReshapeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& dout, + SparseCsrTensor* dx) { + EmptyLikeCsrKernel(dev_ctx, x, dx); + phi::IntArray x_shape(phi::vectorize(x.dims())); + ReshapeCsrKernel(dev_ctx, dout, x_shape, dx); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(reshape_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ReshapeCooGradKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(reshape_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ReshapeCsrGradKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/cpu/reshape_kernel.cc b/paddle/phi/kernels/sparse/cpu/reshape_kernel.cc new file mode 100644 index 00000000000..4f165156668 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/reshape_kernel.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
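The two backward kernels above need no bespoke logic: ReshapeCooGradKernel and ReshapeCsrGradKernel simply reshape the incoming gradient back to x.dims(). A minimal Python sketch of the user-facing behaviour this enables, mirroring the unit test added later in this patch (illustrative only; the input values are arbitrary):

import paddle

x = paddle.to_tensor([[0., 2., 0., 0.],
                      [1., 0., 0., 3.]]).to_sparse_coo(2)
x.stop_gradient = False

# forward: only the indices are remapped, the nonzero values are reused
out = paddle.incubate.sparse.reshape(x, [4, 2])

# backward: the grad kernel reshapes out's gradient back to x's shape
out.backward()
print(x.grad.shape)  # [2, 4]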
+ +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" +#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h" + +namespace phi { +namespace sparse { + +template +void ReshapeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const phi::IntArray& shape, + SparseCooTensor* out) { + // TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims + int64_t x_nnz = x.nnz(); + + // Use DDim::reshape to handle -1 and 0 in the argument "shape" + std::vector new_shape(shape.GetData().begin(), shape.GetData().end()); + phi::DDim out_dims = x.dims().reshape(new_shape); + // get sparse part dimensions of x and out + std::vector x_sparse_part_dims; + std::vector out_sparse_part_dims; + for (int i = 0; i < x.sparse_dim(); ++i) { + x_sparse_part_dims.push_back(x.dims()[i]); + } + for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) { + out_sparse_part_dims.push_back(out_dims[i]); + } + DenseTensor out_indices = Empty( + dev_ctx, {static_cast(out_sparse_part_dims.size()), x_nnz}); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of indices + const DenseTensor& x_indices = x.indices(); + const auto* x_indices_data = x_indices.data(); + auto* out_indices_data = out_indices.data(); + + const phi::DDim& x_sparse_part_strides = + phi::stride(phi::make_ddim(x_sparse_part_dims)); + const phi::DDim& out_sparse_part_strides = + phi::stride(phi::make_ddim(out_sparse_part_dims)); + int64_t location = 0; + for (int64_t j = 0; j < x_nnz; ++j) { + location = 0; + for (int i = 0; i < x.sparse_dim(); ++i) { + location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i]; + } + for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) { + out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i]; + location %= out_sparse_part_strides[i]; + } + } +} + +template +void ReshapeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const phi::IntArray& shape, + SparseCsrTensor* out) { + // transform csr format to coo format, and then use coo kernel + const SparseCooTensor x_coo = CsrToCoo(dev_ctx, x); + SparseCooTensor out_coo; + ReshapeCooKernel(dev_ctx, x_coo, shape, &out_coo); + CooToCsrKernel(dev_ctx, out_coo, out); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(reshape_coo, + CPU, + ALL_LAYOUT, + phi::sparse::ReshapeCooKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(reshape_csr, + CPU, + ALL_LAYOUT, + phi::sparse::ReshapeCsrKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index d0016099cd7..dcb4399aa28 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo, int8_t, int16_t, int, - int64_t) {} + int64_t, + bool) {} PD_REGISTER_KERNEL(coo_to_csr, CPU, @@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr, 
int8_t, int16_t, int, - int64_t) {} + int64_t, + bool) {} PD_REGISTER_KERNEL(dense_to_csr, CPU, diff --git a/paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc new file mode 100644 index 00000000000..87822a9375e --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +std::vector get_cpu_grad_perm(std::vector perm) { + std::vector grad_perm(perm.size()); + for (unsigned int i = 0; i < perm.size(); ++i) { + grad_perm[perm[i]] = i; + } + return grad_perm; +} + +template +void TransposeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& dout, + const std::vector& perm, + SparseCooTensor* dx) { + std::vector grad_perm = get_cpu_grad_perm(perm); + TransposeCooKernel(dev_ctx, dout, grad_perm, dx); +} + +template +void TransposeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& dout, + const std::vector& perm, + SparseCsrTensor* dx) { + std::vector grad_perm = get_cpu_grad_perm(perm); + TransposeCsrKernel(dev_ctx, dout, grad_perm, dx); +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCooGradKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrGradKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/cpu/transpose_kernel.cc b/paddle/phi/kernels/sparse/cpu/transpose_kernel.cc new file mode 100644 index 00000000000..6dea63ffbce --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/transpose_kernel.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
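As the grad kernels above show, the gradient of a transpose is another transpose with the inverse permutation, which get_cpu_grad_perm builds via grad_perm[perm[i]] = i. A self-contained Python check of that identity (the helper name is hypothetical):

def inverse_perm(perm):
    # same construction as get_cpu_grad_perm: grad_perm[perm[i]] = i
    grad_perm = [0] * len(perm)
    for i, p in enumerate(perm):
        grad_perm[p] = i
    return grad_perm

perm = [2, 0, 1]
dims = ['d0', 'd1', 'd2']
transposed = [dims[p] for p in perm]                    # ['d2', 'd0', 'd1']
restored = [transposed[p] for p in inverse_perm(perm)]  # ['d0', 'd1', 'd2']
assert restored == dims  # transposing with the inverse perm undoes the transpose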
+ +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +void TransposeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& perm, + SparseCooTensor* out) { + // create out sparse tensor + int64_t x_nnz = x.nnz(); + DDim out_dims = x.dims().transpose(perm); + DenseTensor out_indices = EmptyLike(dev_ctx, x.indices()); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of indices + const DenseTensor& x_indices = x.indices(); + const auto* x_indices_data = x_indices.data(); + auto* out_indices_data = out_indices.data(); + for (unsigned int i = 0; i < perm.size(); ++i) { + for (int64_t j = 0; j < x_nnz; ++j) { + out_indices_data[j + i * x_nnz] = x_indices_data[j + perm[i] * x_nnz]; + } + } +} + +template +void TransposeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const std::vector& perm, + SparseCsrTensor* out) { + unsigned int n_dim = perm.size(); + const DenseTensor& x_crows = x.crows(); + const DenseTensor& x_cols = x.cols(); + const DenseTensor& x_values = x.values(); + DenseTensor out_crows, out_cols, out_values; + // return a copy of x + if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) { + out_crows = x_crows; + out_cols = x_cols; + out_values = x_values; + out->SetMember(out_crows, out_cols, out_values, x.dims()); + return; + } + // create out sparse tensor + DDim out_dims = x.dims().transpose(perm); + if (n_dim == 2) { + out_crows = Empty(dev_ctx, {out_dims[0] + 1}); + } else { + out_crows = + Empty(dev_ctx, {out_dims[0] * (out_dims[1] + 1)}); + } + out_cols = EmptyLike(dev_ctx, x.cols()); + out_values = EmptyLike(dev_ctx, x.values()); + out->SetMember(out_crows, out_cols, out_values, out_dims); + // transpose by two stages + if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {0, 2, 1}, out); + return; + } else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {0, 2, 1}, &temp); + TransposeCsrKernel(dev_ctx, temp, {1, 0, 2}, out); + return; + } else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {2, 0, 1}, out); + return; + } + + int64_t* out_crows_data = out_crows.data(); + int64_t* out_cols_data = out_cols.data(); + T* out_values_data = out_values.data(); + const int64_t* x_crows_data = x_crows.data(); + const int64_t* x_cols_data = x_cols.data(); + const T* x_values_data = x_values.data(); + + int64_t x_nnz = x.nnz(); + if (n_dim == 2) { // perm == {1, 0} + // compute out_crows_data by x_cols_data + for (int i = 0; i < out_dims[0]; ++i) { + out_crows_data[i] = 0; + } + for (int i = 0; i < x_nnz; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 1]++; + } + out_crows_data[out_dims[0]] = x_nnz; + for (int i = 1; i < out_dims[0]; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + std::unordered_map cols_offset; + 
for (int i = 0; i < x.dims()[0]; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j]; + int64_t jjj = out_crows_data[x_cols_j]; + if (cols_offset.count(jjj)) { + cols_offset[jjj]++; + } else { + cols_offset[jjj] = 0; + } + int64_t jjj_offset = jjj + cols_offset[jjj]; + out_cols_data[jjj_offset] = i; + out_values_data[jjj_offset] = x_values_data[j]; + } + } + } else { // n_dim == 3 + int out_n_rows = out_dims[1]; + int x_n_rows = x.dims()[1]; + for (int k = 0; k < out_dims[0]; ++k) { + if (perm[0] == 0) { // perm == {0, 2, 1} + // compute out_crows_data by x_cols_data + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i] = 0; + } + for (int i = 0; i < x_crows_data[x_n_rows]; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 1]++; + } + out_crows_data[out_n_rows] = x_crows_data[x_n_rows]; + for (int i = 1; i < out_n_rows; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + std::unordered_map cols_offset; + for (int i = 0; i < x_n_rows; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j]; + int64_t jjj = out_crows_data[x_cols_j]; + if (cols_offset.count(jjj)) { + cols_offset[jjj]++; + } else { + cols_offset[jjj] = 0; + } + int64_t jjj_offset = jjj + cols_offset[jjj]; + out_cols_data[jjj_offset] = i; + out_values_data[jjj_offset] = x_values_data[j]; + } + } + // x offset + x_cols_data += x_crows_data[x_n_rows]; + x_values_data += x_crows_data[x_n_rows]; + x_crows_data += x_n_rows + 1; + } else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2} + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i] = 0; + } + int x_cols_offset = 0; + int out_cols_index = 0; + for (int i = 0; i < x.dims()[0]; ++i) { + int x_crows_index = i * (x_n_rows + 1); + int start = x_crows_data[x_crows_index + k]; + int end = x_crows_data[x_crows_index + 1 + k]; + out_crows_data[i + 1] = end - start; + for (int j = start; j < end; ++j) { + out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j]; + out_values_data[out_cols_index] = x_values_data[x_cols_offset + j]; + out_cols_index++; + } + x_cols_offset += x_crows_data[x_crows_index + x_n_rows]; + } + for (int i = 1; i <= out_n_rows; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + } + // out offset + out_cols_data += out_crows_data[out_n_rows]; + out_values_data += out_crows_data[out_n_rows]; + out_crows_data += out_n_rows + 1; + } + } +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCooKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu new file mode 100644 index 00000000000..bfc81676eb8 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc +template +void ReshapeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + SparseCooTensor* dx) { + EmptyLikeCooKernel(dev_ctx, x, dx); + phi::IntArray x_shape(phi::vectorize(x.dims())); + ReshapeCooKernel(dev_ctx, dout, x_shape, dx); +} + +// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc +template +void ReshapeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& dout, + SparseCsrTensor* dx) { + EmptyLikeCsrKernel(dev_ctx, x, dx); + phi::IntArray x_shape(phi::vectorize(x.dims())); + ReshapeCsrKernel(dev_ctx, dout, x_shape, dx); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(reshape_coo_grad, + GPU, + ALL_LAYOUT, + phi::sparse::ReshapeCooGradKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(reshape_csr_grad, + GPU, + ALL_LAYOUT, + phi::sparse::ReshapeCsrGradKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/reshape_kernel.cu b/paddle/phi/kernels/sparse/gpu/reshape_kernel.cu new file mode 100644 index 00000000000..6e3a9842e8c --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/reshape_kernel.cu @@ -0,0 +1,165 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
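Both the CPU ReshapeCooKernel earlier in this patch and the ReshapeCooCudaKernel that follows use the same index remapping: each nonzero's sparse index is flattened to a linear offset with the input's sparse-part strides, then unflattened with the output's strides. A NumPy sketch of that remapping (the function name is hypothetical, not part of the patch):

import numpy as np

def remap_coo_indices(indices, x_sparse_shape, out_sparse_shape):
    # row-major strides of the sparse parts of the input and output shapes
    x_strides = np.append(np.cumprod(x_sparse_shape[::-1])[::-1][1:], 1)
    out_strides = np.append(np.cumprod(out_sparse_shape[::-1])[::-1][1:], 1)
    # flatten each column of indices to a linear location, then unflatten
    location = (indices * x_strides[:, None]).sum(axis=0)
    out = np.empty((len(out_sparse_shape), indices.shape[1]), dtype=np.int64)
    for i, s in enumerate(out_strides):
        out[i] = location // s
        location = location % s
    return out

# a 2x6 layout reshaped to 3x4: index (1, 5) -> linear offset 11 -> (2, 3)
idx = np.array([[0, 1], [2, 5]])
print(remap_coo_indices(idx, [2, 6], [3, 4]))  # [[0 2] [2 3]]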
+ +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h" + +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +namespace phi { +namespace sparse { + +__global__ void ReshapeCooCudaKernel(const int64_t* x_indices_data, + const int num_x_sparse_part_dims, + const int num_out_sparse_part_dims, + const int64_t x_nnz, + const int64_t* x_sparse_part_strides, + const int64_t* out_sparse_part_strides, + int64_t* out_indices_data) { + CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) { + int64_t location = 0; + for (int i = 0; i < num_x_sparse_part_dims; ++i) { + location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i]; + } + for (int i = 0; i < num_out_sparse_part_dims; ++i) { + out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i]; + location %= out_sparse_part_strides[i]; + } + } +} + +template +void ReshapeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const phi::IntArray& shape, + SparseCooTensor* out) { + int64_t x_nnz = x.nnz(); + std::vector new_shape(shape.GetData().begin(), shape.GetData().end()); + phi::DDim out_dims = x.dims().reshape(new_shape); + // get sparse part dimensions of x and out + std::vector x_sparse_part_dims; + std::vector out_sparse_part_dims; + for (int i = 0; i < x.sparse_dim(); ++i) { + x_sparse_part_dims.push_back(x.dims()[i]); + } + for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) { + out_sparse_part_dims.push_back(out_dims[i]); + } + + DenseTensor out_indices = Empty( + dev_ctx, {static_cast(out_sparse_part_dims.size()), x_nnz}); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of out indices + const auto* x_indices_data = x.indices().data(); + auto* out_indices_data = out_indices.data(); + const phi::DDim& x_sparse_part_strides = + phi::stride(phi::make_ddim(x_sparse_part_dims)); + const phi::DDim& out_sparse_part_strides = + phi::stride(phi::make_ddim(out_sparse_part_dims)); + + int64_t *destination_x_sparse_part_strides, + *destination_out_sparse_part_strides; + +#ifdef PADDLE_WITH_HIP + hipMalloc(reinterpret_cast(&destination_x_sparse_part_strides), + sizeof(int64_t) * x_sparse_part_strides.size()); + hipMemcpy(destination_x_sparse_part_strides, + x_sparse_part_strides.Get(), + sizeof(int64_t) * x_sparse_part_strides.size(), + hipMemcpyHostToDevice); + hipMalloc(reinterpret_cast(&destination_out_sparse_part_strides), + sizeof(int64_t) * out_sparse_part_strides.size()); + hipMemcpy(destination_out_sparse_part_strides, + out_sparse_part_strides.Get(), + sizeof(int64_t) * out_sparse_part_strides.size(), + hipMemcpyHostToDevice); +#else + cudaMalloc(reinterpret_cast(&destination_x_sparse_part_strides), + sizeof(int64_t) * x_sparse_part_strides.size()); + cudaMemcpy(destination_x_sparse_part_strides, + x_sparse_part_strides.Get(), + sizeof(int64_t) * x_sparse_part_strides.size(), + cudaMemcpyHostToDevice); + cudaMalloc(reinterpret_cast(&destination_out_sparse_part_strides), + sizeof(int64_t) * out_sparse_part_strides.size()); + cudaMemcpy(destination_out_sparse_part_strides, + out_sparse_part_strides.Get(), + sizeof(int64_t) * out_sparse_part_strides.size(), + cudaMemcpyHostToDevice); +#endif + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 
1); + ReshapeCooCudaKernel<<>>( + x_indices_data, + x_sparse_part_dims.size(), + out_sparse_part_dims.size(), + x_nnz, + destination_x_sparse_part_strides, + destination_out_sparse_part_strides, + out_indices_data); +} + +// just copy from paddle\phi\kernels\sparse\cpu\reshape_kernel.cc +template +void ReshapeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const phi::IntArray& shape, + SparseCsrTensor* out) { + // transform csr format to coo format, and then use coo kernel + const SparseCooTensor x_coo = CsrToCoo(dev_ctx, x); + SparseCooTensor out_coo; + ReshapeCooKernel(dev_ctx, x_coo, shape, &out_coo); + CooToCsrKernel(dev_ctx, out_coo, out); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(reshape_coo, + GPU, + ALL_LAYOUT, + phi::sparse::ReshapeCooKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(reshape_csr, + GPU, + ALL_LAYOUT, + phi::sparse::ReshapeCsrKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index c037f6b1b83..c72a38cd8fd 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -539,7 +539,8 @@ PD_REGISTER_KERNEL(csr_to_coo, int8_t, int16_t, int, - int64_t) {} + int64_t, + bool) {} PD_REGISTER_KERNEL(coo_to_csr, GPU, @@ -552,7 +553,8 @@ PD_REGISTER_KERNEL(coo_to_csr, int8_t, int16_t, int, - int64_t) {} + int64_t, + bool) {} PD_REGISTER_KERNEL(dense_to_csr, GPU, diff --git a/paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu new file mode 100644 index 00000000000..32d842161c2 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
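ReshapeCsrKernel (on both CPU and GPU) avoids a dedicated CSR implementation by round-tripping through COO: CsrToCoo, reshape the COO tensor, then CooToCsrKernel; this is presumably why the csr_to_coo/coo_to_csr registrations above also gain a bool instantiation. The essential step of the CSR-to-COO conversion is expanding the compressed row pointers into one explicit row index per nonzero; an illustrative Python sketch (helper name hypothetical):

import numpy as np

def expand_crows(crows):
    # expand CSR compressed row pointers into per-nonzero row indices
    rows = []
    for r in range(len(crows) - 1):
        rows.extend([r] * (crows[r + 1] - crows[r]))
    return np.array(rows, dtype=np.int64)

# crows of the 4x4 example documented in sparse_csr_tensor.h
print(expand_crows([0, 1, 3, 4, 6]))  # [0 1 1 2 3 3]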
+ +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +std::vector get_gpu_grad_perm(std::vector perm) { + std::vector grad_perm(perm.size()); + for (unsigned int i = 0; i < perm.size(); ++i) { + grad_perm[perm[i]] = i; + } + return grad_perm; +} + +template +void TransposeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& dout, + const std::vector& perm, + SparseCooTensor* dx) { + std::vector grad_perm = get_gpu_grad_perm(perm); + TransposeCooKernel(dev_ctx, dout, grad_perm, dx); +} + +template +void TransposeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& dout, + const std::vector& perm, + SparseCsrTensor* dx) { + std::vector grad_perm = get_gpu_grad_perm(perm); + TransposeCsrKernel(dev_ctx, dout, grad_perm, dx); +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo_grad, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCooGradKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr_grad, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrGradKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu b/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu new file mode 100644 index 00000000000..692076b80e9 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu @@ -0,0 +1,338 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
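For COO tensors a transpose never touches the values: TransposeCooKernel (CPU, earlier in this patch) and the TransposeCooCudaKernel that follows only reorder the rows of the indices matrix according to perm. A NumPy sketch of that reordering (function name hypothetical):

import numpy as np

def transpose_coo_indices(indices, perm):
    # out row i holds what x row perm[i] held; the values stay in place
    return indices[list(perm), :]

indices = np.array([[0, 0, 1],   # coordinates along dim 0
                    [1, 2, 0],   # coordinates along dim 1
                    [2, 0, 1]])  # coordinates along dim 2
print(transpose_coo_indices(indices, (2, 1, 0)))
# [[2 0 1]
#  [1 2 0]
#  [0 0 1]]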
+ +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +__global__ void TransposeCooCudaKernel(const int64_t *x_indices_data, + const int *perm, + const std::size_t n_dim, + const int64_t x_nnz, + int64_t *out_indices_data) { + CUDA_KERNEL_LOOP_TYPE(index, x_nnz * n_dim, int64_t) { + int64_t i = index / x_nnz; + int64_t j = index % x_nnz; + out_indices_data[index] = x_indices_data[j + perm[i] * x_nnz]; + } +} + +template +__global__ void TransposeCsr2DCudaKernel(const int64_t *x_crows_data, + const int64_t *x_cols_data, + const T *x_values_data, + const int *perm, + const int64_t *x_dims, + const int64_t *out_dims, + const int64_t x_nnz, + int64_t *out_crows_data, + int64_t *out_cols_data, + T *out_values_data) { + int64_t __index__ = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + // compute out_crows_data by x_cols_data + for (int64_t i = __index__; i <= out_dims[0]; i += blockDim.x * gridDim.x) { + out_crows_data[i] = 0; + } + __syncthreads(); + if (__index__ == 0) { + for (int64_t i = 0; i < x_nnz; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 2]++; + } + for (int64_t i = 0; i < out_dims[0]; i += 1) { + out_crows_data[i + 1] += out_crows_data[i]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + for (int i = 0; i < x_dims[0]; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j] + 1; + int64_t jjj = out_crows_data[x_cols_j]; + out_cols_data[jjj] = i; + out_values_data[jjj] = x_values_data[j]; + out_crows_data[x_cols_j]++; + } + } + } +} + +template +__global__ void TransposeCsr3DCudaKernel(const int64_t *x_crows_data, + const int64_t *x_cols_data, + const T *x_values_data, + const int *perm, + const int64_t *x_dims, + const int64_t *out_dims, + const std::size_t n_dim, + const int64_t x_nnz, + int64_t *out_crows_data, + int64_t *out_cols_data, + T *out_values_data) { + int64_t __index__ = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (__index__ == 0) { + int out_n_rows = out_dims[1]; + int x_n_rows = x_dims[1]; + for (int k = 0; k < out_dims[0]; ++k) { + if (perm[0] == 0) { // dims == {0, 2, 1} + // compute out_crows_data by x_cols_data + for (int i = 0; i <= out_n_rows; ++i) { + out_crows_data[i] = 0; + } + for (int i = 0; i < x_crows_data[x_n_rows]; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 2]++; + } + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i + 1] += out_crows_data[i]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + for (int i = 0; i < x_n_rows; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j] + 1; + int64_t jjj = out_crows_data[x_cols_j]; + out_cols_data[jjj] = i; + out_values_data[jjj] = x_values_data[j]; + out_crows_data[x_cols_j]++; + } + } + // x offset + x_cols_data += x_crows_data[x_n_rows]; + x_values_data += x_crows_data[x_n_rows]; + x_crows_data += x_n_rows + 1; + } else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2} + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i] = 0; + } + int x_cols_offset = 0; + int out_cols_index = 0; + for (int i = 0; i < 
x_dims[0]; ++i) { + int x_crows_index = i * (x_n_rows + 1); + int start = x_crows_data[x_crows_index + k]; + int end = x_crows_data[x_crows_index + 1 + k]; + out_crows_data[i + 1] = end - start; + for (int j = start; j < end; ++j) { + out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j]; + out_values_data[out_cols_index] = x_values_data[x_cols_offset + j]; + out_cols_index++; + } + x_cols_offset += x_crows_data[x_crows_index + x_n_rows]; + } + for (int i = 1; i <= out_n_rows; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + } + // out offset + out_cols_data += out_crows_data[out_n_rows]; + out_values_data += out_crows_data[out_n_rows]; + out_crows_data += out_n_rows + 1; + } + } +} + +template +void TransposeCooKernel(const Context &dev_ctx, + const SparseCooTensor &x, + const std::vector &perm, + SparseCooTensor *out) { + // create out sparse tensor + int64_t x_nnz = x.nnz(); + std::size_t n_dim = perm.size(); + DDim out_dims = x.dims().transpose(perm); + DenseTensor out_indices = EmptyLike(dev_ctx, x.indices()); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of indices + const DenseTensor &x_indices = x.indices(); + const auto *x_indices_data = x_indices.data(); + auto *out_indices_data = out_indices.data(); + int *d_perm; +#ifdef PADDLE_WITH_HIP + hipMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + hipMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice); +#else + cudaMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + cudaMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice); +#endif + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz * n_dim, 1); + TransposeCooCudaKernel<<>>( + x_indices_data, d_perm, n_dim, x_nnz, out_indices_data); +} + +template +void TransposeCsrKernel(const Context &dev_ctx, + const SparseCsrTensor &x, + const std::vector &perm, + SparseCsrTensor *out) { + std::size_t n_dim = perm.size(); + const DenseTensor &x_crows = x.crows(); + const DenseTensor &x_cols = x.cols(); + const DenseTensor &x_values = x.non_zero_elements(); + DenseTensor out_crows, out_cols, out_values; + // return a copy of x + if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) { + out_crows = x_crows; + out_cols = x_cols; + out_values = x_values; + out->SetMember(out_crows, out_cols, out_values, x.dims()); + return; + } + // create out sparse tensor + DDim out_dims = x.dims().transpose(perm); + if (n_dim == 2) { + out_crows = Empty(dev_ctx, {out_dims[0] + 1}); + } else { + out_crows = + Empty(dev_ctx, {out_dims[0] * (out_dims[1] + 1)}); + } + out_cols = EmptyLike(dev_ctx, x.cols()); + out_values = EmptyLike(dev_ctx, x.values()); + out->SetMember(out_crows, out_cols, out_values, out_dims); + // transpose by two stages + if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {0, 2, 1}, out); + return; + } else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {0, 2, 1}, &temp); + TransposeCsrKernel(dev_ctx, temp, {1, 0, 2}, out); + return; + } else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {2, 0, 1}, out); + return; + } + int64_t *out_crows_data = out_crows.data(); + int64_t 
*out_cols_data = out_cols.data(); + T *out_values_data = out_values.data(); + const int64_t *x_crows_data = x_crows.data(); + const int64_t *x_cols_data = x_cols.data(); + const T *x_values_data = x_values.data(); + int *d_perm; + int64_t *d_x_dims, *d_out_dims; +#ifdef PADDLE_WITH_HIP + hipMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + hipMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice); + hipMalloc(reinterpret_cast(&d_x_dims), + sizeof(int64_t) * x.dims().size()); + hipMemcpy(d_x_dims, + x.dims().Get(), + sizeof(int64_t) * x.dims().size(), + hipMemcpyHostToDevice); + hipMalloc(reinterpret_cast(&d_out_dims), + sizeof(int64_t) * out_dims.size()); + hipMemcpy(d_out_dims, + out_dims.Get(), + sizeof(int64_t) * out_dims.size(), + hipMemcpyHostToDevice); +#else + cudaMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + cudaMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice); + cudaMalloc(reinterpret_cast(&d_x_dims), + sizeof(int64_t) * x.dims().size()); + cudaMemcpy(d_x_dims, + x.dims().Get(), + sizeof(int64_t) * x.dims().size(), + cudaMemcpyHostToDevice); + cudaMalloc(reinterpret_cast(&d_out_dims), + sizeof(int64_t) * out_dims.size()); + cudaMemcpy(d_out_dims, + out_dims.Get(), + sizeof(int64_t) * out_dims.size(), + cudaMemcpyHostToDevice); +#endif + int64_t x_nnz = x.nnz(); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_dims[0], 1); + if (perm.size() == 2) { + TransposeCsr2DCudaKernel<<>>(x_crows_data, + x_cols_data, + x_values_data, + d_perm, + d_x_dims, + d_out_dims, + x_nnz, + out_crows_data, + out_cols_data, + out_values_data); + } else { + TransposeCsr3DCudaKernel<<<1, 1, 0, dev_ctx.stream()>>>(x_crows_data, + x_cols_data, + x_values_data, + d_perm, + d_x_dims, + d_out_dims, + perm.size(), + x_nnz, + out_crows_data, + out_cols_data, + out_values_data); + } +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCooKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.h b/paddle/phi/kernels/sparse/unary_grad_kernel.h index eb2cf9ed697..b446e1b99ed 100644 --- a/paddle/phi/kernels/sparse/unary_grad_kernel.h +++ b/paddle/phi/kernels/sparse/unary_grad_kernel.h @@ -77,5 +77,29 @@ void CastCsrGradKernel(const Context& dev_ctx, DataType value_dtype, SparseCsrTensor* dx); +template +void TransposeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& dout, + const std::vector& perm, + SparseCooTensor* dx); + +template +void TransposeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& dout, + const std::vector& perm, + SparseCsrTensor* dx); + +template +void ReshapeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + SparseCooTensor* dx); + +template +void ReshapeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& dout, + SparseCsrTensor* dx); + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h index fdb6b21a444..a81e724d1fe 100644 --- a/paddle/phi/kernels/sparse/unary_kernel.h +++ 
b/paddle/phi/kernels/sparse/unary_kernel.h @@ -14,6 +14,8 @@ #pragma once +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -99,6 +101,48 @@ void CastCsrKernel(const Context& dev_ctx, DataType value_dtype, SparseCsrTensor* out); +template +void TransposeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& perm, + SparseCooTensor* out); + +template +void TransposeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const std::vector& perm, + SparseCsrTensor* out); + +template +SparseCooTensor TransposeCoo(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& perm) { + PADDLE_ENFORCE_EQ(x.sparse_dim(), + perm.size(), + phi::errors::InvalidArgument( + "size of perm must be equal than the x.sparse_dim()")); + SparseCooTensor coo; + TransposeCooKernel(dev_ctx, x, perm, &coo); + return coo; +} + +template +SparseCsrTensor TransposeCsr(const Context& dev_ctx, + const SparseCsrTensor& x, + const std::vector& perm) { + PADDLE_ENFORCE_LE( + 2, + perm.size(), + phi::errors::InvalidArgument("size of perm must be equal to 2 or 3")); + PADDLE_ENFORCE_GE( + 3, + perm.size(), + phi::errors::InvalidArgument("size of perm must be equal to 2 or 3")); + SparseCsrTensor csr; + TransposeCsrKernel(dev_ctx, x, perm, &csr); + return csr; +} + template SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) { SparseCooTensor coo; @@ -113,5 +157,43 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) { return csr; } +template +void ReshapeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const phi::IntArray& shape, + SparseCooTensor* out); + +template +void ReshapeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const phi::IntArray& shape, + SparseCsrTensor* out); + +template +SparseCooTensor ReshapeCoo(const Context& dev_ctx, + const SparseCooTensor& x, + const phi::IntArray& shape) { + SparseCooTensor coo; + ReshapeCooKernel(dev_ctx, x, shape, &coo); + return coo; +} + +template +SparseCsrTensor ReshapeCsr(const Context& dev_ctx, + const SparseCsrTensor& x, + const phi::IntArray& shape) { + PADDLE_ENFORCE_LE( + 2, + shape.size(), + phi::errors::InvalidArgument("size of shape must be equal to 2 or 3")); + PADDLE_ENFORCE_GE( + 3, + shape.size(), + phi::errors::InvalidArgument("size of shape must be equal to 2 or 3")); + SparseCsrTensor csr; + ReshapeCsrKernel(dev_ctx, x, shape, &csr); + return csr; +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index d1c9d25483f..600638cd845 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -74,6 +74,10 @@ cc_test( test_sparse_elementwise_dev_api SRCS test_sparse_elementwise_dev_api.cc DEPS phi phi_api_utils) +cc_test( + test_sparse_transpose_dev_api + SRCS test_sparse_transpose_dev_api.cc + DEPS phi phi_api_utils) cc_test( test_math_function diff --git a/paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc new file mode 100644 index 00000000000..b2d5ed1d61b --- /dev/null +++ b/paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" +#include "paddle/phi/kernels/transpose_grad_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +namespace phi { +namespace tests { + +TEST(DEV_API, sparse_transpose_coo) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_coo = sparse::DenseToCoo(dev_ctx_cpu, dense_x, 3); + auto sparse_out = + sparse::TransposeCoo(dev_ctx_cpu, sparse_coo, {2, 1, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out); + + // backward + DenseTensor dense_grad_x = phi::EmptyLike(dev_ctx_cpu, dense_out); + TransposeGradKernel(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x); + SparseCooTensor sparse_grad_x; + sparse::EmptyLikeCooKernel(dev_ctx_cpu, sparse_coo, &sparse_grad_x); + + SparseCooTensor sparse_out_grad( + sparse_coo.indices(), sparse_coo.values(), {2, 2, 3}); + sparse::TransposeCooGradKernel( + dev_ctx_cpu, sparse_out_grad, {2, 1, 0}, &sparse_grad_x); +} + +TEST(DEV_API, sparse_transpose_csr_case1) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_csr = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + + auto sparse_out = + sparse::TransposeCsr(dev_ctx_cpu, sparse_csr, {2, 1, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({2, 2, 3}), 
DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out); + + // backward + DenseTensor dense_grad_x = phi::EmptyLike(dev_ctx_cpu, dense_out); + TransposeGradKernel(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x); + SparseCsrTensor sparse_grad_x; + sparse::EmptyLikeCsrKernel(dev_ctx_cpu, sparse_csr, &sparse_grad_x); + sparse::TransposeCsrGradKernel( + dev_ctx_cpu, sparse_out, {2, 1, 0}, &sparse_grad_x); +} + +TEST(DEV_API, sparse_transpose_csr_case2) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_csr = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + + auto sparse_out = + sparse::TransposeCsr(dev_ctx_cpu, sparse_csr, {1, 2, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {1, 2, 0}, &dense_out); +} + +TEST(DEV_API, sparse_transpose_csr_case3) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 4}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_csr = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + + auto sparse_out = + sparse::TransposeCsr(dev_ctx_cpu, sparse_csr, {1, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({4, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {1, 0}, &dense_out); +} + +} // namespace tests +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_is_same_shape.py b/python/paddle/fluid/tests/unittests/test_sparse_is_same_shape.py new file mode 100644 index 00000000000..aac5d04a763 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_is_same_shape.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
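The Python test that follows exercises is_same_shape, which reduces to the plain shape comparison added to eager_method.cc at the top of this patch. A condensed usage sketch mirroring the test's setup (illustrative only):

import paddle
from paddle.incubate.sparse.binary import is_same_shape

x = paddle.rand([2, 5, 8])
y = paddle.rand([2, 5, 8])
z = paddle.rand([3, 4])

# the comparison works across dense, COO and CSR layouts
print(is_same_shape(x, y.to_sparse_coo(2)))  # True
print(is_same_shape(x.to_sparse_csr(), z))   # False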
+ +from __future__ import print_function +import unittest + +import paddle +from paddle.incubate.sparse.binary import is_same_shape + + +class TestSparseIsSameShapeAPI(unittest.TestCase): + """ + test paddle.incubate.sparse.is_same_shape + """ + + def setUp(self): + self.shapes = [[2, 5, 8], [3, 4]] + self.tensors = [ + paddle.rand(self.shapes[0]), + paddle.rand(self.shapes[0]), + paddle.rand(self.shapes[1]) + ] + self.sparse_dim = 2 + + def test_dense_dense(self): + self.assertTrue(is_same_shape(self.tensors[0], self.tensors[1])) + self.assertFalse(is_same_shape(self.tensors[0], self.tensors[2])) + self.assertFalse(is_same_shape(self.tensors[1], self.tensors[2])) + + def test_dense_csr(self): + self.assertTrue( + is_same_shape(self.tensors[0], self.tensors[1].to_sparse_csr())) + self.assertFalse( + is_same_shape(self.tensors[0], self.tensors[2].to_sparse_csr())) + self.assertFalse( + is_same_shape(self.tensors[1], self.tensors[2].to_sparse_csr())) + + def test_dense_coo(self): + self.assertTrue( + is_same_shape(self.tensors[0], + self.tensors[1].to_sparse_coo(self.sparse_dim))) + self.assertFalse( + is_same_shape(self.tensors[0], + self.tensors[2].to_sparse_coo(self.sparse_dim))) + self.assertFalse( + is_same_shape(self.tensors[1], + self.tensors[2].to_sparse_coo(self.sparse_dim))) + + def test_csr_dense(self): + self.assertTrue( + is_same_shape(self.tensors[0].to_sparse_csr(), self.tensors[1])) + self.assertFalse( + is_same_shape(self.tensors[0].to_sparse_csr(), self.tensors[2])) + self.assertFalse( + is_same_shape(self.tensors[1].to_sparse_csr(), self.tensors[2])) + + def test_csr_csr(self): + self.assertTrue( + is_same_shape(self.tensors[0].to_sparse_csr(), + self.tensors[1].to_sparse_csr())) + self.assertFalse( + is_same_shape(self.tensors[0].to_sparse_csr(), + self.tensors[2].to_sparse_csr())) + self.assertFalse( + is_same_shape(self.tensors[1].to_sparse_csr(), + self.tensors[2].to_sparse_csr())) + + def test_csr_coo(self): + self.assertTrue( + is_same_shape(self.tensors[0].to_sparse_csr(), + self.tensors[1].to_sparse_coo(self.sparse_dim))) + self.assertFalse( + is_same_shape(self.tensors[0].to_sparse_csr(), + self.tensors[2].to_sparse_coo(self.sparse_dim))) + self.assertFalse( + is_same_shape(self.tensors[1].to_sparse_csr(), + self.tensors[2].to_sparse_coo(self.sparse_dim))) + + def test_coo_dense(self): + self.assertTrue( + is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim), + self.tensors[1])) + self.assertFalse( + is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim), + self.tensors[2])) + self.assertFalse( + is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim), + self.tensors[2])) + + def test_coo_csr(self): + self.assertTrue( + is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim), + self.tensors[1].to_sparse_csr())) + self.assertFalse( + is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim), + self.tensors[2].to_sparse_csr())) + self.assertFalse( + is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim), + self.tensors[2].to_sparse_csr())) + + def test_coo_coo(self): + self.assertTrue( + is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim), + self.tensors[1].to_sparse_coo(self.sparse_dim))) + self.assertFalse( + is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim), + self.tensors[2].to_sparse_coo(self.sparse_dim))) + self.assertFalse( + is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim), + self.tensors[2].to_sparse_coo(self.sparse_dim))) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py b/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py new file mode 100644 index 00000000000..e9ef737f774 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest + + +class TestReshape(unittest.TestCase): + """ + Test the API paddle.incubate.sparse.reshape on some sparse tensors. + x: sparse, out: sparse + """ + + def check_result(self, x_shape, new_shape, format): + """ + x_shape: original shape + new_shape: new shape + format: "coo" or "csr" + Transform a sparse tensor with shape "x_shape" to + a sparse tensor with shape "new_shape". + Compare the output of paddle.reshape and the output of + paddle.incubate.sparse.reshape. + """ + mask = np.random.randint(0, 2, x_shape) + np_x = np.random.randint(-100, 100, x_shape) * mask + + # check cpu kernel + dense_x = paddle.to_tensor(np_x, place=paddle.CPUPlace()) + dense_x.stop_gradient = False + dense_out = paddle.reshape(dense_x, new_shape) + + if format == "coo": + sp_x = paddle.to_tensor(np_x, + place=paddle.CPUPlace()).to_sparse_coo( + len(x_shape)) + else: + sp_x = paddle.to_tensor(np_x, + place=paddle.CPUPlace()).to_sparse_csr() + sp_x.stop_gradient = False + sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape) + + np.testing.assert_allclose(sp_out.to_dense().numpy(), + dense_out.numpy(), + rtol=1e-05) + + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose(sp_x.grad.to_dense().numpy(), + dense_x.grad.numpy() * + np_x.astype('bool').astype('int'), + rtol=1e-05) + + # check gpu kernel + if paddle.device.is_compiled_with_cuda(): + dense_x = paddle.to_tensor(np_x, place=paddle.CUDAPlace(0)) + dense_x.stop_gradient = False + dense_out = paddle.reshape(dense_x, new_shape) + + if format == "coo": + sp_x = paddle.to_tensor( + np_x, place=paddle.CUDAPlace(0)).to_sparse_coo(len(x_shape)) + else: + sp_x = paddle.to_tensor( + np_x, place=paddle.CUDAPlace(0)).to_sparse_csr() + sp_x.stop_gradient = False + sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape) + + np.testing.assert_allclose(sp_out.to_dense().numpy(), + dense_out.numpy(), + rtol=1e-05) + + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose(sp_x.grad.to_dense().numpy(), + dense_x.grad.numpy() * + np_x.astype('bool').astype('int'), + rtol=1e-05) + + def test_reshape_2d(self): + self.check_result([2, 5], [ + 10, + ], 'coo') + self.check_result([12, 5], [15, 4], 'coo') + + self.check_result([10, 5], [2, 25], 'csr') + self.check_result([9, 8], [18, 4], 'csr') + + def test_reshape_3d(self): + self.check_result([6, 2, 3], [6, 2, 3], 'coo') + self.check_result([6, 2, 3], [2, 3, 3, 2], 'coo') + self.check_result([6, 2, 3], [1, 18, 2], 'coo') + self.check_result([6, 2, 3], [2, 9, 2], 'coo') + self.check_result([6, 2, 3], [2, 1, 18], 'coo') + self.check_result([6, 2, 3], [1, 2, 
2, 3, 3], 'coo') + + self.check_result([6, 2, 3], [6, 2, 3], 'csr') + self.check_result([6, 2, 3], [6, 3, 2], 'csr') + self.check_result([6, 2, 3], [2, 6, 3], 'csr') + self.check_result([6, 2, 3], [3, 6, 2], 'csr') + self.check_result([6, 2, 3], [4, 9, 1], 'csr') + self.check_result([6, 2, 3], [12, 1, 3], 'csr') + + def test_reshape_nd(self): + self.check_result([8, 3, 4, 4, 5, 3], [24, 8, 10, 3], 'coo') + self.check_result([3, 4, 4, 5, 7], [1, 12, 2, 5, 14], 'coo') + + def test_reshape_with_zero_or_minus_one_in_new_shape(self): + self.check_result([6, 2, 3], [-1, 0, 3], 'coo') + self.check_result([6, 2, 3], [2, 3, 0, -1], 'coo') + self.check_result([6, 2, 3], [1, -1, 2], 'coo') + self.check_result([6, 2, 3], [-1, 9, 2], 'coo') + self.check_result([6, 2, 3], [2, -1, 18], 'coo') + self.check_result([6, 2, 3], [1, 0, 2, -1, 3], 'coo') + + self.check_result([6, 2, 3], [0, 0, -1], 'csr') + self.check_result([6, 2, 3], [-1, 3, 2], 'csr') + self.check_result([6, 2, 3], [2, -1, 0], 'csr') + self.check_result([6, 2, 3], [-1, 6, 2], 'csr') + self.check_result([6, 2, 3], [-1, 9, 1], 'csr') + self.check_result([6, 2, 3], [-1, 1, 3], 'csr') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py new file mode 100644 index 00000000000..58bcbdc8c00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py @@ -0,0 +1,77 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest +from paddle.fluid.framework import _test_eager_guard + + +class TestTranspose(unittest.TestCase): + # x: sparse, out: sparse + def check_result(self, x_shape, dims, format): + with _test_eager_guard(): + mask = paddle.randint(0, 2, x_shape).astype("float32") + # "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask", + # or the backward checks may fail. 
+ origin_x = (paddle.rand(x_shape, dtype='float32') + 1) * mask + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_out = paddle.transpose(dense_x, dims) + + if format == "coo": + sp_x = origin_x.detach().to_sparse_coo(len(x_shape)) + else: + sp_x = origin_x.detach().to_sparse_csr() + sp_x.stop_gradient = False + sp_out = paddle.incubate.sparse.transpose(sp_x, dims) + + np.testing.assert_allclose(sp_out.to_dense().numpy(), + dense_out.numpy(), + rtol=1e-05) + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose(sp_x.grad.to_dense().numpy(), + (dense_x.grad * mask).numpy(), + rtol=1e-05) + + def test_transpose_2d(self): + self.check_result([2, 5], [0, 1], 'coo') + self.check_result([2, 5], [0, 1], 'csr') + self.check_result([2, 5], [1, 0], 'coo') + self.check_result([2, 5], [1, 0], 'csr') + + def test_transpose_3d(self): + self.check_result([6, 2, 3], [0, 1, 2], 'coo') + self.check_result([6, 2, 3], [0, 1, 2], 'csr') + self.check_result([6, 2, 3], [0, 2, 1], 'coo') + self.check_result([6, 2, 3], [0, 2, 1], 'csr') + self.check_result([6, 2, 3], [1, 0, 2], 'coo') + self.check_result([6, 2, 3], [1, 0, 2], 'csr') + self.check_result([6, 2, 3], [2, 0, 1], 'coo') + self.check_result([6, 2, 3], [2, 0, 1], 'csr') + self.check_result([6, 2, 3], [2, 1, 0], 'coo') + self.check_result([6, 2, 3], [2, 1, 0], 'csr') + self.check_result([6, 2, 3], [1, 2, 0], 'coo') + self.check_result([6, 2, 3], [1, 2, 0], 'csr') + + def test_transpose_nd(self): + self.check_result([8, 3, 4, 4, 5, 3], [5, 3, 4, 1, 0, 2], 'coo') + # Randint now only supports access to dimension 0 to 9. + self.check_result([2, 3, 4, 2, 3, 4, 2, 3, 4], + [2, 3, 4, 5, 6, 7, 8, 0, 1], 'coo') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index 8408c3ca277..8b6866fa4da 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ b/python/paddle/incubate/sparse/__init__.py @@ -34,6 +34,8 @@ from .unary import coalesce from .unary import deg2rad from .unary import rad2deg from .unary import expm1 +from .unary import transpose +from .unary import reshape from .binary import mv from .binary import matmul @@ -42,39 +44,16 @@ from .binary import add from .binary import divide from .binary import multiply from .binary import subtract +from .binary import is_same_shape from .multiary import addmm from . 
import nn
 
 __all__ = [
-    'sparse_coo_tensor',
-    'sparse_csr_tensor',
-    'sin',
-    'tan',
-    'asin',
-    'atan',
-    'sinh',
-    'tanh',
-    'asinh',
-    'atanh',
-    'sqrt',
-    'square',
-    'log1p',
-    'abs',
-    'pow',
-    'cast',
-    'neg',
-    'deg2rad',
-    'rad2deg',
-    'expm1',
-    'mv',
-    'matmul',
-    'masked_matmul',
-    'addmm',
-    'add',
-    'subtract',
-    'multiply',
-    'divide',
-    'coalesce',
+    'sparse_coo_tensor', 'sparse_csr_tensor', 'sin', 'tan', 'asin', 'atan',
+    'sinh', 'tanh', 'asinh', 'atanh', 'sqrt', 'square', 'log1p', 'abs', 'pow',
+    'cast', 'neg', 'deg2rad', 'rad2deg', 'expm1', 'mv', 'matmul',
+    'masked_matmul', 'addmm', 'add', 'subtract', 'transpose', 'multiply',
+    'divide', 'coalesce', 'is_same_shape', 'reshape'
 ]
diff --git a/python/paddle/incubate/sparse/binary.py b/python/paddle/incubate/sparse/binary.py
index 39d80508b1c..aa0231924c9 100644
--- a/python/paddle/incubate/sparse/binary.py
+++ b/python/paddle/incubate/sparse/binary.py
@@ -414,3 +414,36 @@ def divide(x, y, name=None):
     if y.dtype != x.dtype:
         y = _C_ops.sparse_cast(y, None, x.dtype)
     return _C_ops.sparse_divide(x, y)
+
+
+@dygraph_only
+def is_same_shape(x, y):
+    """
+    Check whether the shape of ``x`` equals the shape of ``y`` and return the result as a bool.
+    Any combination of DenseTensor, SparseCooTensor and SparseCsrTensor is supported.
+
+    Args:
+        x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
+        y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
+
+    Returns:
+        bool: True if the two shapes are the same, False otherwise.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.rand([2, 3, 8])
+            y = paddle.rand([2, 3, 8])
+            y = y.to_sparse_csr()
+            z = paddle.rand([2, 5])
+
+            paddle.incubate.sparse.is_same_shape(x, y)
+            # True
+            paddle.incubate.sparse.is_same_shape(x, z)
+            # False
+
+    """
+    return x.is_same_shape(y)
diff --git a/python/paddle/incubate/sparse/unary.py b/python/paddle/incubate/sparse/unary.py
index 472a71d482b..eac098b2bfc 100644
--- a/python/paddle/incubate/sparse/unary.py
+++ b/python/paddle/incubate/sparse/unary.py
@@ -119,6 +119,37 @@ def asin(x, name=None):
     return _C_ops.sparse_asin(x)
 
 
+@dygraph_only
+def transpose(x, perm, name=None):
+    """
+    Permutes the dimensions of ``x`` according to ``perm`` without changing its data.
+    ``x`` must be a SparseCooTensor or a SparseCsrTensor.
+
+    .. math::
+
+        out = transpose(x, perm)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        perm (list|tuple): The permutation applied to the dimensions of the input.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A transposed Sparse Tensor with the same data type as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([[-2., 0.], [1., 2.]])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.transpose(sparse_x, [1, 0])
+
+    """
+    return _C_ops.sparse_transpose(x, perm)
+
+
 @dygraph_only
 def atan(x, name=None):
     """
@@ -608,3 +639,60 @@ def expm1(x, name=None):
             out = paddle.incubate.sparse.expm1(sparse_x)
     """
     return _C_ops.sparse_expm1(x)
+
+
+@dygraph_only
+def reshape(x, shape, name=None):
+    """
+    Changes the shape of ``x`` without changing its value. ``x`` must be a SparseCooTensor or
+    a SparseCsrTensor. Currently this function can only reshape the sparse dims of ``x``, but
+    the ``shape`` argument must still specify the full shape of the reshaped tensor.
+
+    Note that if x is a SparseCsrTensor, then len(shape) must be 2 or 3.
+
+    There are some tricks when specifying the target shape.
+
+        - 1. -1 means the value of this dimension is inferred from the total element number of x and the remaining dimensions, so at most one dimension can be set to -1.
+
+        - 2. 0 means the actual dimension value is copied from the corresponding dimension of x. The indices of the 0s in the target shape can not exceed the rank of x.
+
+    Here are some examples to explain it.
+
+        - 1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [6, 8], the reshape operator will transform x into a 2-D tensor with shape [6, 8], leaving x's data unchanged.
+
+        - 2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [2, 3, -1, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 3, 4, 2], leaving x's data unchanged. In this case, one dimension of the target shape is set to -1, and its value is inferred from the total element number of x and the remaining dimensions.
+
+        - 3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 4, 3, 2], leaving x's data unchanged. In this case, besides -1, the 0 means the actual dimension value is copied from the corresponding dimension of x.
+
+    Args:
+        x (Tensor): The input sparse tensor with data type ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``.
+        shape (list|tuple): Define the target shape. At most one dimension of the target shape can be -1.
+            The data type is ``int32``.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: A reshaped Tensor with the same data type as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x_shape = [6, 2, 3]
+            new_shape = [1, 0, 2, -1, 3]
+            format = "coo"
+
+            dense_x = paddle.randint(-100, 100, x_shape) * paddle.randint(0, 2, x_shape)
+
+            if format == "coo":
+                sp_x = dense_x.to_sparse_coo(len(x_shape))
+            else:
+                sp_x = dense_x.to_sparse_csr()
+            sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
+
+            print(sp_out)
+            # the shape of sp_out is [1, 2, 2, 3, 3]
+
+    """
+    return _C_ops.sparse_reshape(x, shape)
-- 
GitLab
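Taken together, the three new user-facing APIs compose naturally. A minimal end-to-end sketch, assuming a Paddle build that includes this patch; the tensor values are illustrative only.

.. code-block:: python

    import paddle

    dense = paddle.to_tensor([[0., 2., 0., 4.],
                              [0., 0., 6., 0.],
                              [8., 0., 0., 0.]])
    coo = dense.to_sparse_coo(2)

    # transpose swaps the two sparse dims: [3, 4] -> [4, 3]
    t = paddle.incubate.sparse.transpose(coo, [1, 0])

    # reshape: 0 copies dim 0 of the input (4), -1 is inferred (3)
    r = paddle.incubate.sparse.reshape(t, [0, -1])

    # is_same_shape compares shapes only, across dense/COO/CSR formats
    print(paddle.incubate.sparse.is_same_shape(r, t))      # True
    print(paddle.incubate.sparse.is_same_shape(r, dense))  # False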