diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 8347ee200e815c505478b977d3058c15234263dc..de49f6f27fe36dd2879b846cb7cea5e491df392f 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -385,6 +385,17 @@ kernel : func : coo_to_dense { sparse_coo -> dense } +- backward_op : transpose_grad + forward : transpose(Tensor x, int[] perm) -> Tensor(out) + args : (Tensor out_grad, int[] perm) + output : Tensor(x_grad) + infer_meta : + func : TransposeGradInferMeta + param : [out_grad, perm] + kernel : + func : transpose_coo_grad {sparse_coo -> sparse_coo}, + transpose_csr_grad {sparse_csr -> sparse_csr} + - backward_op : values_grad forward : values_coo(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index a917012b2f7916b8e6ea8b23cc6e096e4134870d..1d7a4c0bafe53e998120690ac476f7a7e8479a16 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -457,3 +457,15 @@ mv_csr{sparse_csr, dense -> dense} layout : x backward: mv_grad + +- op : transpose + args : (Tensor x, int[] perm) + output : Tensor(out) + infer_meta : + func : TransposeInferMeta + param: [ x, perm ] + kernel : + func : transpose_coo{sparse_coo -> sparse_coo}, + transpose_csr{sparse_csr -> sparse_csr} + layout : x + backward : transpose_grad diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index f8869aa524d3250faffca654342508fe3151f4df..a28229996c887140de13b6f3f509ff918703e7a6 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -274,7 +274,7 @@ class SparseCooTensor : public TensorBase, [0, 0, 0, 0]] dims_ = (4, 4) non_zero_elements_ = [[0, 1, 0, 0], [0, 0, 4, 0]] - non_zero_indices_ = [0, 2], + non_zero_indices_ = [[0, 2], [1, 2]] */ }; diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index 056d049942a2b99649742dcc9d4aee18e0c37327..2acb35915a9c36b0f9a45619097e6af30158e835 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase, [0, 0, 4, 0], [0, 5, 0, 6]] dims_ = (4, 4) - non_zero_elements_ = [1, 2, 3, 4, 5 ,6] + non_zero_elements_ = [1, 2, 3, 4, 5, 6] non_zero_crows_ = [0, 1, 3, 4, 6] non_zero_cols_ = [1, 0, 3, 2, 1, 3] */ @@ -228,7 +228,7 @@ class SparseCsrTensor : public TensorBase, [0, 0, 4, 0], [0, 5, 0, 0]]] dims_ = (2, 4, 4) - non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5] + non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5] non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5] non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1] */ diff --git a/paddle/phi/infermeta/sparse/unary.h b/paddle/phi/infermeta/sparse/unary.h index 880e90b7ae697ffeb25cf7f74642c224798932bd..11961033012ef22ef277501f237fa2383f684448 100644 --- a/paddle/phi/infermeta/sparse/unary.h +++ b/paddle/phi/infermeta/sparse/unary.h @@ -24,5 +24,12 @@ void IndicesInferMeta(const MetaTensor& x, MetaTensor* out); void ValuesInferMeta(const MetaTensor& x, MetaTensor* out); +void TransposeInferMeta(const MetaTensor& x, + const std::vector& axis, + MetaTensor* out); + +void TransposeGradInferMeta(const MetaTensor& x, + const std::vector& axis, + MetaTensor* out); } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..87822a9375ef5c4df50b53952962862bc51954df --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +std::vector get_cpu_grad_perm(std::vector perm) { + std::vector grad_perm(perm.size()); + for (unsigned int i = 0; i < perm.size(); ++i) { + grad_perm[perm[i]] = i; + } + return grad_perm; +} + +template +void TransposeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& dout, + const std::vector& perm, + SparseCooTensor* dx) { + std::vector grad_perm = get_cpu_grad_perm(perm); + TransposeCooKernel(dev_ctx, dout, grad_perm, dx); +} + +template +void TransposeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& dout, + const std::vector& perm, + SparseCsrTensor* dx) { + std::vector grad_perm = get_cpu_grad_perm(perm); + TransposeCsrKernel(dev_ctx, dout, grad_perm, dx); +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCooGradKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrGradKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/cpu/transpose_kernel.cc b/paddle/phi/kernels/sparse/cpu/transpose_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6dea63ffbce88d991f13a7c1e1fc7aa13a62a315 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/transpose_kernel.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +void TransposeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& perm, + SparseCooTensor* out) { + // create out sparse tensor + int64_t x_nnz = x.nnz(); + DDim out_dims = x.dims().transpose(perm); + DenseTensor out_indices = EmptyLike(dev_ctx, x.indices()); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of indices + const DenseTensor& x_indices = x.indices(); + const auto* x_indices_data = x_indices.data(); + auto* out_indices_data = out_indices.data(); + for (unsigned int i = 0; i < perm.size(); ++i) { + for (int64_t j = 0; j < x_nnz; ++j) { + out_indices_data[j + i * x_nnz] = x_indices_data[j + perm[i] * x_nnz]; + } + } +} + +template +void TransposeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const std::vector& perm, + SparseCsrTensor* out) { + unsigned int n_dim = perm.size(); + const DenseTensor& x_crows = x.crows(); + const DenseTensor& x_cols = x.cols(); + const DenseTensor& x_values = x.values(); + DenseTensor out_crows, out_cols, out_values; + // return a copy of x + if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) { + out_crows = x_crows; + out_cols = x_cols; + out_values = x_values; + out->SetMember(out_crows, out_cols, out_values, x.dims()); + return; + } + // create out sparse tensor + DDim out_dims = x.dims().transpose(perm); + if (n_dim == 2) { + out_crows = Empty(dev_ctx, {out_dims[0] + 1}); + } else { + out_crows = + Empty(dev_ctx, {out_dims[0] * (out_dims[1] + 1)}); + } + out_cols = EmptyLike(dev_ctx, x.cols()); + out_values = EmptyLike(dev_ctx, x.values()); + out->SetMember(out_crows, out_cols, out_values, out_dims); + // transpose by two stages + if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {0, 2, 1}, out); + return; + } else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {0, 2, 1}, &temp); + TransposeCsrKernel(dev_ctx, temp, {1, 0, 2}, out); + return; + } else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {2, 0, 1}, out); + return; + } + + int64_t* out_crows_data = out_crows.data(); + int64_t* out_cols_data = out_cols.data(); + T* out_values_data = out_values.data(); + const int64_t* x_crows_data = x_crows.data(); + const int64_t* x_cols_data = x_cols.data(); + const T* x_values_data = x_values.data(); + + int64_t x_nnz = x.nnz(); + if (n_dim == 2) { // perm == {1, 0} + // compute out_crows_data by x_cols_data + for (int i = 0; i < out_dims[0]; ++i) { + out_crows_data[i] = 0; + } + for (int i = 0; i < x_nnz; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 1]++; + } + out_crows_data[out_dims[0]] = x_nnz; + for (int i = 1; i < out_dims[0]; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + std::unordered_map cols_offset; + for (int i = 0; i < x.dims()[0]; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j]; + int64_t jjj = out_crows_data[x_cols_j]; + if (cols_offset.count(jjj)) { + cols_offset[jjj]++; + } else { + cols_offset[jjj] = 0; + } + int64_t jjj_offset = jjj + cols_offset[jjj]; + out_cols_data[jjj_offset] = i; + out_values_data[jjj_offset] = x_values_data[j]; + } + } + } else { // n_dim == 3 + int out_n_rows = out_dims[1]; + int x_n_rows = x.dims()[1]; + for (int k = 0; k < out_dims[0]; ++k) { + if (perm[0] == 0) { // perm == {0, 2, 1} + // compute out_crows_data by x_cols_data + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i] = 0; + } + for (int i = 0; i < x_crows_data[x_n_rows]; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 1]++; + } + out_crows_data[out_n_rows] = x_crows_data[x_n_rows]; + for (int i = 1; i < out_n_rows; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + std::unordered_map cols_offset; + for (int i = 0; i < x_n_rows; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j]; + int64_t jjj = out_crows_data[x_cols_j]; + if (cols_offset.count(jjj)) { + cols_offset[jjj]++; + } else { + cols_offset[jjj] = 0; + } + int64_t jjj_offset = jjj + cols_offset[jjj]; + out_cols_data[jjj_offset] = i; + out_values_data[jjj_offset] = x_values_data[j]; + } + } + // x offset + x_cols_data += x_crows_data[x_n_rows]; + x_values_data += x_crows_data[x_n_rows]; + x_crows_data += x_n_rows + 1; + } else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2} + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i] = 0; + } + int x_cols_offset = 0; + int out_cols_index = 0; + for (int i = 0; i < x.dims()[0]; ++i) { + int x_crows_index = i * (x_n_rows + 1); + int start = x_crows_data[x_crows_index + k]; + int end = x_crows_data[x_crows_index + 1 + k]; + out_crows_data[i + 1] = end - start; + for (int j = start; j < end; ++j) { + out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j]; + out_values_data[out_cols_index] = x_values_data[x_cols_offset + j]; + out_cols_index++; + } + x_cols_offset += x_crows_data[x_crows_index + x_n_rows]; + } + for (int i = 1; i <= out_n_rows; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + } + // out offset + out_cols_data += out_crows_data[out_n_rows]; + out_values_data += out_crows_data[out_n_rows]; + out_crows_data += out_n_rows + 1; + } + } +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCooKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr, + CPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..32d842161c2e54ea1209e10bd7aded497598b492 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +std::vector get_gpu_grad_perm(std::vector perm) { + std::vector grad_perm(perm.size()); + for (unsigned int i = 0; i < perm.size(); ++i) { + grad_perm[perm[i]] = i; + } + return grad_perm; +} + +template +void TransposeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& dout, + const std::vector& perm, + SparseCooTensor* dx) { + std::vector grad_perm = get_gpu_grad_perm(perm); + TransposeCooKernel(dev_ctx, dout, grad_perm, dx); +} + +template +void TransposeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& dout, + const std::vector& perm, + SparseCsrTensor* dx) { + std::vector grad_perm = get_gpu_grad_perm(perm); + TransposeCsrKernel(dev_ctx, dout, grad_perm, dx); +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo_grad, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCooGradKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr_grad, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrGradKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu b/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..692076b80e9efd7b011426bb130ef19c4851aef4 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu @@ -0,0 +1,338 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +__global__ void TransposeCooCudaKernel(const int64_t *x_indices_data, + const int *perm, + const std::size_t n_dim, + const int64_t x_nnz, + int64_t *out_indices_data) { + CUDA_KERNEL_LOOP_TYPE(index, x_nnz * n_dim, int64_t) { + int64_t i = index / x_nnz; + int64_t j = index % x_nnz; + out_indices_data[index] = x_indices_data[j + perm[i] * x_nnz]; + } +} + +template +__global__ void TransposeCsr2DCudaKernel(const int64_t *x_crows_data, + const int64_t *x_cols_data, + const T *x_values_data, + const int *perm, + const int64_t *x_dims, + const int64_t *out_dims, + const int64_t x_nnz, + int64_t *out_crows_data, + int64_t *out_cols_data, + T *out_values_data) { + int64_t __index__ = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + // compute out_crows_data by x_cols_data + for (int64_t i = __index__; i <= out_dims[0]; i += blockDim.x * gridDim.x) { + out_crows_data[i] = 0; + } + __syncthreads(); + if (__index__ == 0) { + for (int64_t i = 0; i < x_nnz; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 2]++; + } + for (int64_t i = 0; i < out_dims[0]; i += 1) { + out_crows_data[i + 1] += out_crows_data[i]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + for (int i = 0; i < x_dims[0]; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j] + 1; + int64_t jjj = out_crows_data[x_cols_j]; + out_cols_data[jjj] = i; + out_values_data[jjj] = x_values_data[j]; + out_crows_data[x_cols_j]++; + } + } + } +} + +template +__global__ void TransposeCsr3DCudaKernel(const int64_t *x_crows_data, + const int64_t *x_cols_data, + const T *x_values_data, + const int *perm, + const int64_t *x_dims, + const int64_t *out_dims, + const std::size_t n_dim, + const int64_t x_nnz, + int64_t *out_crows_data, + int64_t *out_cols_data, + T *out_values_data) { + int64_t __index__ = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (__index__ == 0) { + int out_n_rows = out_dims[1]; + int x_n_rows = x_dims[1]; + for (int k = 0; k < out_dims[0]; ++k) { + if (perm[0] == 0) { // dims == {0, 2, 1} + // compute out_crows_data by x_cols_data + for (int i = 0; i <= out_n_rows; ++i) { + out_crows_data[i] = 0; + } + for (int i = 0; i < x_crows_data[x_n_rows]; ++i) { + int j = x_cols_data[i]; + out_crows_data[j + 2]++; + } + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i + 1] += out_crows_data[i]; + } + // compute out_cols_data and out_values_data by out_crows_data and x + for (int i = 0; i < x_n_rows; ++i) { + int64_t start = x_crows_data[i]; + int64_t end = x_crows_data[i + 1]; + for (int64_t j = start; j < end; ++j) { + int64_t x_cols_j = x_cols_data[j] + 1; + int64_t jjj = out_crows_data[x_cols_j]; + out_cols_data[jjj] = i; + out_values_data[jjj] = x_values_data[j]; + out_crows_data[x_cols_j]++; + } + } + // x offset + x_cols_data += x_crows_data[x_n_rows]; + x_values_data += x_crows_data[x_n_rows]; + x_crows_data += x_n_rows + 1; + } else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2} + for (int i = 0; i < out_n_rows; ++i) { + out_crows_data[i] = 0; + } + int x_cols_offset = 0; + int out_cols_index = 0; + for (int i = 0; i < x_dims[0]; ++i) { + int x_crows_index = i * (x_n_rows + 1); + int start = x_crows_data[x_crows_index + k]; + int end = x_crows_data[x_crows_index + 1 + k]; + out_crows_data[i + 1] = end - start; + for (int j = start; j < end; ++j) { + out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j]; + out_values_data[out_cols_index] = x_values_data[x_cols_offset + j]; + out_cols_index++; + } + x_cols_offset += x_crows_data[x_crows_index + x_n_rows]; + } + for (int i = 1; i <= out_n_rows; ++i) { + out_crows_data[i] += out_crows_data[i - 1]; + } + } + // out offset + out_cols_data += out_crows_data[out_n_rows]; + out_values_data += out_crows_data[out_n_rows]; + out_crows_data += out_n_rows + 1; + } + } +} + +template +void TransposeCooKernel(const Context &dev_ctx, + const SparseCooTensor &x, + const std::vector &perm, + SparseCooTensor *out) { + // create out sparse tensor + int64_t x_nnz = x.nnz(); + std::size_t n_dim = perm.size(); + DDim out_dims = x.dims().transpose(perm); + DenseTensor out_indices = EmptyLike(dev_ctx, x.indices()); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of indices + const DenseTensor &x_indices = x.indices(); + const auto *x_indices_data = x_indices.data(); + auto *out_indices_data = out_indices.data(); + int *d_perm; +#ifdef PADDLE_WITH_HIP + hipMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + hipMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice); +#else + cudaMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + cudaMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice); +#endif + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz * n_dim, 1); + TransposeCooCudaKernel<<>>( + x_indices_data, d_perm, n_dim, x_nnz, out_indices_data); +} + +template +void TransposeCsrKernel(const Context &dev_ctx, + const SparseCsrTensor &x, + const std::vector &perm, + SparseCsrTensor *out) { + std::size_t n_dim = perm.size(); + const DenseTensor &x_crows = x.crows(); + const DenseTensor &x_cols = x.cols(); + const DenseTensor &x_values = x.non_zero_elements(); + DenseTensor out_crows, out_cols, out_values; + // return a copy of x + if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) { + out_crows = x_crows; + out_cols = x_cols; + out_values = x_values; + out->SetMember(out_crows, out_cols, out_values, x.dims()); + return; + } + // create out sparse tensor + DDim out_dims = x.dims().transpose(perm); + if (n_dim == 2) { + out_crows = Empty(dev_ctx, {out_dims[0] + 1}); + } else { + out_crows = + Empty(dev_ctx, {out_dims[0] * (out_dims[1] + 1)}); + } + out_cols = EmptyLike(dev_ctx, x.cols()); + out_values = EmptyLike(dev_ctx, x.values()); + out->SetMember(out_crows, out_cols, out_values, out_dims); + // transpose by two stages + if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {0, 2, 1}, out); + return; + } else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {0, 2, 1}, &temp); + TransposeCsrKernel(dev_ctx, temp, {1, 0, 2}, out); + return; + } else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0} + SparseCsrTensor temp; + TransposeCsrKernel(dev_ctx, x, {1, 0, 2}, &temp); + TransposeCsrKernel(dev_ctx, temp, {2, 0, 1}, out); + return; + } + int64_t *out_crows_data = out_crows.data(); + int64_t *out_cols_data = out_cols.data(); + T *out_values_data = out_values.data(); + const int64_t *x_crows_data = x_crows.data(); + const int64_t *x_cols_data = x_cols.data(); + const T *x_values_data = x_values.data(); + int *d_perm; + int64_t *d_x_dims, *d_out_dims; +#ifdef PADDLE_WITH_HIP + hipMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + hipMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice); + hipMalloc(reinterpret_cast(&d_x_dims), + sizeof(int64_t) * x.dims().size()); + hipMemcpy(d_x_dims, + x.dims().Get(), + sizeof(int64_t) * x.dims().size(), + hipMemcpyHostToDevice); + hipMalloc(reinterpret_cast(&d_out_dims), + sizeof(int64_t) * out_dims.size()); + hipMemcpy(d_out_dims, + out_dims.Get(), + sizeof(int64_t) * out_dims.size(), + hipMemcpyHostToDevice); +#else + cudaMalloc(reinterpret_cast(&d_perm), sizeof(int) * perm.size()); + cudaMemcpy( + d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice); + cudaMalloc(reinterpret_cast(&d_x_dims), + sizeof(int64_t) * x.dims().size()); + cudaMemcpy(d_x_dims, + x.dims().Get(), + sizeof(int64_t) * x.dims().size(), + cudaMemcpyHostToDevice); + cudaMalloc(reinterpret_cast(&d_out_dims), + sizeof(int64_t) * out_dims.size()); + cudaMemcpy(d_out_dims, + out_dims.Get(), + sizeof(int64_t) * out_dims.size(), + cudaMemcpyHostToDevice); +#endif + int64_t x_nnz = x.nnz(); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_dims[0], 1); + if (perm.size() == 2) { + TransposeCsr2DCudaKernel<<>>(x_crows_data, + x_cols_data, + x_values_data, + d_perm, + d_x_dims, + d_out_dims, + x_nnz, + out_crows_data, + out_cols_data, + out_values_data); + } else { + TransposeCsr3DCudaKernel<<<1, 1, 0, dev_ctx.stream()>>>(x_crows_data, + x_cols_data, + x_values_data, + d_perm, + d_x_dims, + d_out_dims, + perm.size(), + x_nnz, + out_crows_data, + out_cols_data, + out_values_data); + } +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(transpose_coo, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCooKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(transpose_csr, + GPU, + ALL_LAYOUT, + phi::sparse::TransposeCsrKernel, + phi::dtype::float16, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.h b/paddle/phi/kernels/sparse/unary_grad_kernel.h index eb2cf9ed697e9d03b8de7a0714dfdfef71a1fd3c..933e1967e68c334770c04ab7989bc08b58b1be2a 100644 --- a/paddle/phi/kernels/sparse/unary_grad_kernel.h +++ b/paddle/phi/kernels/sparse/unary_grad_kernel.h @@ -77,5 +77,17 @@ void CastCsrGradKernel(const Context& dev_ctx, DataType value_dtype, SparseCsrTensor* dx); +template +void TransposeCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& dout, + const std::vector& perm, + SparseCooTensor* dx); + +template +void TransposeCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& dout, + const std::vector& perm, + SparseCsrTensor* dx); + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h index fdb6b21a44427c111a6d06dc254af2aae724cfa7..fb5cd21ed39211c6e49f5b3ab292b2071c19c5a7 100644 --- a/paddle/phi/kernels/sparse/unary_kernel.h +++ b/paddle/phi/kernels/sparse/unary_kernel.h @@ -99,6 +99,48 @@ void CastCsrKernel(const Context& dev_ctx, DataType value_dtype, SparseCsrTensor* out); +template +void TransposeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& perm, + SparseCooTensor* out); + +template +void TransposeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const std::vector& perm, + SparseCsrTensor* out); + +template +SparseCooTensor TransposeCoo(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& perm) { + PADDLE_ENFORCE_EQ(x.sparse_dim(), + perm.size(), + phi::errors::InvalidArgument( + "size of perm must be equal than the x.sparse_dim()")); + SparseCooTensor coo; + TransposeCooKernel(dev_ctx, x, perm, &coo); + return coo; +} + +template +SparseCsrTensor TransposeCsr(const Context& dev_ctx, + const SparseCsrTensor& x, + const std::vector& perm) { + PADDLE_ENFORCE_LE( + 2, + perm.size(), + phi::errors::InvalidArgument("size of perm must be equal to 2 or 3")); + PADDLE_ENFORCE_GE( + 3, + perm.size(), + phi::errors::InvalidArgument("size of perm must be equal to 2 or 3")); + SparseCsrTensor csr; + TransposeCsrKernel(dev_ctx, x, perm, &csr); + return csr; +} + template SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) { SparseCooTensor coo; diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index 09349ef782bba0fa352f43e33b4c66a87ce1d290..96085539877ecafe1ecc8e632b406a6112b7372c 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -74,6 +74,10 @@ cc_test( test_sparse_elementwise_dev_api SRCS test_sparse_elementwise_dev_api.cc DEPS phi phi_api_utils) +cc_test( + test_sparse_transpose_dev_api + SRCS test_sparse_transpose_dev_api.cc + DEPS phi phi_api_utils) cc_test( test_math_function diff --git a/paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2d5ed1d61b49467c7822913a91436f3b0ad190e --- /dev/null +++ b/paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" +#include "paddle/phi/kernels/transpose_grad_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +namespace phi { +namespace tests { + +TEST(DEV_API, sparse_transpose_coo) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_coo = sparse::DenseToCoo(dev_ctx_cpu, dense_x, 3); + auto sparse_out = + sparse::TransposeCoo(dev_ctx_cpu, sparse_coo, {2, 1, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out); + + // backward + DenseTensor dense_grad_x = phi::EmptyLike(dev_ctx_cpu, dense_out); + TransposeGradKernel(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x); + SparseCooTensor sparse_grad_x; + sparse::EmptyLikeCooKernel(dev_ctx_cpu, sparse_coo, &sparse_grad_x); + + SparseCooTensor sparse_out_grad( + sparse_coo.indices(), sparse_coo.values(), {2, 2, 3}); + sparse::TransposeCooGradKernel( + dev_ctx_cpu, sparse_out_grad, {2, 1, 0}, &sparse_grad_x); +} + +TEST(DEV_API, sparse_transpose_csr_case1) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_csr = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + + auto sparse_out = + sparse::TransposeCsr(dev_ctx_cpu, sparse_csr, {2, 1, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out); + + // backward + DenseTensor dense_grad_x = phi::EmptyLike(dev_ctx_cpu, dense_out); + TransposeGradKernel(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x); + SparseCsrTensor sparse_grad_x; + sparse::EmptyLikeCsrKernel(dev_ctx_cpu, sparse_csr, &sparse_grad_x); + sparse::TransposeCsrGradKernel( + dev_ctx_cpu, sparse_out, {2, 1, 0}, &sparse_grad_x); +} + +TEST(DEV_API, sparse_transpose_csr_case2) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_csr = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + + auto sparse_out = + sparse::TransposeCsr(dev_ctx_cpu, sparse_csr, {1, 2, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {1, 2, 0}, &dense_out); +} + +TEST(DEV_API, sparse_transpose_csr_case3) { + std::vector data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0}; + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + + DenseTensor dense_x = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({3, 4}), DataLayout::NCHW)); + memcpy(dense_x.data(), data.data(), data.size() * sizeof(float)); + auto sparse_csr = sparse::DenseToCsr(dev_ctx_cpu, dense_x); + + auto sparse_out = + sparse::TransposeCsr(dev_ctx_cpu, sparse_csr, {1, 0}); + DenseTensor dense_out = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta( + DataType::FLOAT32, phi::make_ddim({4, 3}), DataLayout::NCHW)); + TransposeKernel(dev_ctx_cpu, dense_x, {1, 0}, &dense_out); +} + +} // namespace tests +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b14d27e605ba33d1800650708e3559f9177ea4b0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest +from paddle.fluid.framework import _test_eager_guard + + +class TestTranspose(unittest.TestCase): + # x: sparse, out: sparse + def check_result(self, x_shape, dims, format): + with _test_eager_guard(): + mask = paddle.randint(0, 2, x_shape).astype("float32") + origin_x = paddle.rand(x_shape, dtype='float32') * mask + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_out = paddle.transpose(dense_x, dims) + + if format == "coo": + sp_x = origin_x.detach().to_sparse_coo(len(x_shape)) + else: + sp_x = origin_x.detach().to_sparse_csr() + sp_x.stop_gradient = False + sp_out = paddle.incubate.sparse.transpose(sp_x, dims) + + np.testing.assert_allclose(sp_out.to_dense().numpy(), + dense_out.numpy(), + rtol=1e-05) + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose(sp_x.grad.to_dense().numpy(), + (dense_x.grad * mask).numpy(), + rtol=1e-05) + + def test_transpose_2d(self): + self.check_result([2, 5], [0, 1], 'coo') + self.check_result([2, 5], [0, 1], 'csr') + self.check_result([2, 5], [1, 0], 'coo') + self.check_result([2, 5], [1, 0], 'csr') + + def test_transpose_3d(self): + self.check_result([6, 2, 3], [0, 1, 2], 'coo') + self.check_result([6, 2, 3], [0, 1, 2], 'csr') + self.check_result([6, 2, 3], [0, 2, 1], 'coo') + self.check_result([6, 2, 3], [0, 2, 1], 'csr') + self.check_result([6, 2, 3], [1, 0, 2], 'coo') + self.check_result([6, 2, 3], [1, 0, 2], 'csr') + self.check_result([6, 2, 3], [2, 0, 1], 'coo') + self.check_result([6, 2, 3], [2, 0, 1], 'csr') + self.check_result([6, 2, 3], [2, 1, 0], 'coo') + self.check_result([6, 2, 3], [2, 1, 0], 'csr') + self.check_result([6, 2, 3], [1, 2, 0], 'coo') + self.check_result([6, 2, 3], [1, 2, 0], 'csr') + + def test_transpose_nd(self): + self.check_result([8, 3, 4, 4, 5, 3], [5, 3, 4, 1, 0, 2], 'coo') + # Randint now only supports access to dimension 0 to 9. + self.check_result([2, 3, 4, 2, 3, 4, 2, 3, 4], + [2, 3, 4, 5, 6, 7, 8, 0, 1], 'coo') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index de89f46438a352014daa2309f37045a995022aaa..581310fbbd9b6b548fd4c6866451b370d82a2426 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ b/python/paddle/incubate/sparse/__init__.py @@ -34,6 +34,7 @@ from .unary import coalesce from .unary import deg2rad from .unary import rad2deg from .unary import expm1 +from .unary import transpose from .binary import mv from .binary import matmul @@ -75,6 +76,7 @@ __all__ = [ 'addmm', 'add', 'subtract', + 'transpose', 'multiply', 'divide', 'coalesce', diff --git a/python/paddle/incubate/sparse/unary.py b/python/paddle/incubate/sparse/unary.py index bb18a5715479fb401135c0a7612c65082b5a40bc..7090ef44c75d987b3ac1c63fb36e2d7284859c3d 100644 --- a/python/paddle/incubate/sparse/unary.py +++ b/python/paddle/incubate/sparse/unary.py @@ -119,6 +119,37 @@ def asin(x, name=None): return _C_ops.sparse_asin(x) +@dygraph_only +def transpose(x, perm, name=None): + """ + Changes the perm order of ``x`` without changing its data, requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + out = transpose(x, perm) + + Parameters: + x (Tensor): The input Sparse Tensor with data type float32, float64. + perm (list|tuple): Permute the input according to the data of perm. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A transposed Sparse Tensor with the same data type as ``x``. + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([[-2., 0.], [1., 2.]]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.transpose(sparse_x, [1, 0]) + + """ + return _C_ops.sparse_transpose(x, perm) + + @dygraph_only def atan(x, name=None): """