未验证 提交 5fef043d 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[cherry-pick 2.4] add sparse api transpose/reshape/is_same_shape (#47076)

新增sparse.is_same_shape、sparse.reshape、sparse.transpose 三个API
上级 5a44c124
......@@ -1588,6 +1588,15 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_same_shape(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
return ToPyObject(self->tensor.shape() == other.shape());
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__inplace_version(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
......@@ -1983,6 +1992,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_same_shape",
(PyCFunction)(void (*)(void))tensor_method_is_same_shape,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"to_sparse_csr",
(PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
......
......@@ -260,6 +260,17 @@
func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : reshape_grad
forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : scale_grad
forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, float scale)
......@@ -385,6 +396,17 @@
kernel :
func : coo_to_dense { sparse_coo -> dense }
- backward_op : transpose_grad
forward : transpose(Tensor x, int[] perm) -> Tensor(out)
args : (Tensor out_grad, int[] perm)
output : Tensor(x_grad)
infer_meta :
func : TransposeGradInferMeta
param : [out_grad, perm]
kernel :
func : transpose_coo_grad {sparse_coo -> sparse_coo},
transpose_csr_grad {sparse_csr -> sparse_csr}
- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -457,3 +457,26 @@
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad
- op : transpose
args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
param: [ x, perm ]
kernel :
func : transpose_coo{sparse_coo -> sparse_coo},
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad
- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
layout : x
backward : reshape_grad
......@@ -274,7 +274,7 @@ class SparseCooTensor : public TensorBase,
[0, 0, 0, 0]]
dims_ = (4, 4)
non_zero_elements_ = [[0, 1, 0, 0], [0, 0, 4, 0]]
non_zero_indices_ = [0, 2],
non_zero_indices_ = [[0, 2], [1, 2]]
*/
};
......
......@@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 6]]
dims_ = (4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6]
non_zero_elements_ = [1, 2, 3, 4, 5, 6]
non_zero_crows_ = [0, 1, 3, 4, 6]
non_zero_cols_ = [1, 0, 3, 2, 1, 3]
*/
......@@ -228,7 +228,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 0]]]
dims_ = (2, 4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5]
non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5]
non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5]
non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1]
*/
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();
// Use DDim::reshape to handle -1 and 0 in the argument "shape"
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t location = 0;
for (int64_t j = 0; j < x_nnz; ++j) {
location = 0;
for (int i = 0; i < x.sparse_dim(); ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
......@@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(coo_to_csr,
CPU,
......@@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(dense_to_csr,
CPU,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
std::vector<int> get_cpu_grad_perm(std::vector<int> perm) {
std::vector<int> grad_perm(perm.size());
for (unsigned int i = 0; i < perm.size(); ++i) {
grad_perm[perm[i]] = i;
}
return grad_perm;
}
template <typename T, typename Context>
void TransposeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout,
const std::vector<int>& perm,
SparseCooTensor* dx) {
std::vector<int> grad_perm = get_cpu_grad_perm(perm);
TransposeCooKernel<T, Context>(dev_ctx, dout, grad_perm, dx);
}
template <typename T, typename Context>
void TransposeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
const std::vector<int>& perm,
SparseCsrTensor* dx) {
std::vector<int> grad_perm = get_cpu_grad_perm(perm);
TransposeCsrKernel<T, Context>(dev_ctx, dout, grad_perm, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(transpose_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(transpose_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void TransposeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int>& perm,
SparseCooTensor* out) {
// create out sparse tensor
int64_t x_nnz = x.nnz();
DDim out_dims = x.dims().transpose(perm);
DenseTensor out_indices = EmptyLike<int64_t, Context>(dev_ctx, x.indices());
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
for (unsigned int i = 0; i < perm.size(); ++i) {
for (int64_t j = 0; j < x_nnz; ++j) {
out_indices_data[j + i * x_nnz] = x_indices_data[j + perm[i] * x_nnz];
}
}
}
template <typename T, typename Context>
void TransposeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int>& perm,
SparseCsrTensor* out) {
unsigned int n_dim = perm.size();
const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
const DenseTensor& x_values = x.values();
DenseTensor out_crows, out_cols, out_values;
// return a copy of x
if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) {
out_crows = x_crows;
out_cols = x_cols;
out_values = x_values;
out->SetMember(out_crows, out_cols, out_values, x.dims());
return;
}
// create out sparse tensor
DDim out_dims = x.dims().transpose(perm);
if (n_dim == 2) {
out_crows = Empty<int64_t, Context>(dev_ctx, {out_dims[0] + 1});
} else {
out_crows =
Empty<int64_t, Context>(dev_ctx, {out_dims[0] * (out_dims[1] + 1)});
}
out_cols = EmptyLike<int64_t, Context>(dev_ctx, x.cols());
out_values = EmptyLike<T, Context>(dev_ctx, x.values());
out->SetMember(out_crows, out_cols, out_values, out_dims);
// transpose by two stages
if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {0, 2, 1}, out);
return;
} else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {1, 0, 2}, out);
return;
} else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {2, 0, 1}, out);
return;
}
int64_t* out_crows_data = out_crows.data<int64_t>();
int64_t* out_cols_data = out_cols.data<int64_t>();
T* out_values_data = out_values.data<T>();
const int64_t* x_crows_data = x_crows.data<int64_t>();
const int64_t* x_cols_data = x_cols.data<int64_t>();
const T* x_values_data = x_values.data<T>();
int64_t x_nnz = x.nnz();
if (n_dim == 2) { // perm == {1, 0}
// compute out_crows_data by x_cols_data
for (int i = 0; i < out_dims[0]; ++i) {
out_crows_data[i] = 0;
}
for (int i = 0; i < x_nnz; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
out_crows_data[out_dims[0]] = x_nnz;
for (int i = 1; i < out_dims[0]; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
// compute out_cols_data and out_values_data by out_crows_data and x
std::unordered_map<int64_t, int> cols_offset;
for (int i = 0; i < x.dims()[0]; ++i) {
int64_t start = x_crows_data[i];
int64_t end = x_crows_data[i + 1];
for (int64_t j = start; j < end; ++j) {
int64_t x_cols_j = x_cols_data[j];
int64_t jjj = out_crows_data[x_cols_j];
if (cols_offset.count(jjj)) {
cols_offset[jjj]++;
} else {
cols_offset[jjj] = 0;
}
int64_t jjj_offset = jjj + cols_offset[jjj];
out_cols_data[jjj_offset] = i;
out_values_data[jjj_offset] = x_values_data[j];
}
}
} else { // n_dim == 3
int out_n_rows = out_dims[1];
int x_n_rows = x.dims()[1];
for (int k = 0; k < out_dims[0]; ++k) {
if (perm[0] == 0) { // perm == {0, 2, 1}
// compute out_crows_data by x_cols_data
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
for (int i = 0; i < x_crows_data[x_n_rows]; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
out_crows_data[out_n_rows] = x_crows_data[x_n_rows];
for (int i = 1; i < out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
// compute out_cols_data and out_values_data by out_crows_data and x
std::unordered_map<int64_t, int> cols_offset;
for (int i = 0; i < x_n_rows; ++i) {
int64_t start = x_crows_data[i];
int64_t end = x_crows_data[i + 1];
for (int64_t j = start; j < end; ++j) {
int64_t x_cols_j = x_cols_data[j];
int64_t jjj = out_crows_data[x_cols_j];
if (cols_offset.count(jjj)) {
cols_offset[jjj]++;
} else {
cols_offset[jjj] = 0;
}
int64_t jjj_offset = jjj + cols_offset[jjj];
out_cols_data[jjj_offset] = i;
out_values_data[jjj_offset] = x_values_data[j];
}
}
// x offset
x_cols_data += x_crows_data[x_n_rows];
x_values_data += x_crows_data[x_n_rows];
x_crows_data += x_n_rows + 1;
} else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2}
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
int x_cols_offset = 0;
int out_cols_index = 0;
for (int i = 0; i < x.dims()[0]; ++i) {
int x_crows_index = i * (x_n_rows + 1);
int start = x_crows_data[x_crows_index + k];
int end = x_crows_data[x_crows_index + 1 + k];
out_crows_data[i + 1] = end - start;
for (int j = start; j < end; ++j) {
out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j];
out_values_data[out_cols_index] = x_values_data[x_cols_offset + j];
out_cols_index++;
}
x_cols_offset += x_crows_data[x_crows_index + x_n_rows];
}
for (int i = 1; i <= out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
}
// out offset
out_cols_data += out_crows_data[out_n_rows];
out_values_data += out_crows_data[out_n_rows];
out_crows_data += out_n_rows + 1;
}
}
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(transpose_coo,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(transpose_csr,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
namespace phi {
namespace sparse {
__global__ void ReshapeCooCudaKernel(const int64_t* x_indices_data,
const int num_x_sparse_part_dims,
const int num_out_sparse_part_dims,
const int64_t x_nnz,
const int64_t* x_sparse_part_strides,
const int64_t* out_sparse_part_strides,
int64_t* out_indices_data) {
CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
int64_t location = 0;
for (int i = 0; i < num_x_sparse_part_dims; ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (int i = 0; i < num_out_sparse_part_dims; ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
int64_t x_nnz = x.nnz();
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of out indices
const auto* x_indices_data = x.indices().data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t *destination_x_sparse_part_strides,
*destination_out_sparse_part_strides;
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
sizeof(int64_t) * x_sparse_part_strides.size());
hipMemcpy(destination_x_sparse_part_strides,
x_sparse_part_strides.Get(),
sizeof(int64_t) * x_sparse_part_strides.size(),
hipMemcpyHostToDevice);
hipMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
sizeof(int64_t) * out_sparse_part_strides.size());
hipMemcpy(destination_out_sparse_part_strides,
out_sparse_part_strides.Get(),
sizeof(int64_t) * out_sparse_part_strides.size(),
hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
sizeof(int64_t) * x_sparse_part_strides.size());
cudaMemcpy(destination_x_sparse_part_strides,
x_sparse_part_strides.Get(),
sizeof(int64_t) * x_sparse_part_strides.size(),
cudaMemcpyHostToDevice);
cudaMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
sizeof(int64_t) * out_sparse_part_strides.size());
cudaMemcpy(destination_out_sparse_part_strides,
out_sparse_part_strides.Get(),
sizeof(int64_t) * out_sparse_part_strides.size(),
cudaMemcpyHostToDevice);
#endif
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
ReshapeCooCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
x_indices_data,
x_sparse_part_dims.size(),
out_sparse_part_dims.size(),
x_nnz,
destination_x_sparse_part_strides,
destination_out_sparse_part_strides,
out_indices_data);
}
// just copy from paddle\phi\kernels\sparse\cpu\reshape_kernel.cc
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(reshape_coo,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(reshape_csr,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
......@@ -539,7 +539,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(coo_to_csr,
GPU,
......@@ -552,7 +553,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}
PD_REGISTER_KERNEL(dense_to_csr,
GPU,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
std::vector<int> get_gpu_grad_perm(std::vector<int> perm) {
std::vector<int> grad_perm(perm.size());
for (unsigned int i = 0; i < perm.size(); ++i) {
grad_perm[perm[i]] = i;
}
return grad_perm;
}
template <typename T, typename Context>
void TransposeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout,
const std::vector<int>& perm,
SparseCooTensor* dx) {
std::vector<int> grad_perm = get_gpu_grad_perm(perm);
TransposeCooKernel<T, Context>(dev_ctx, dout, grad_perm, dx);
}
template <typename T, typename Context>
void TransposeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
const std::vector<int>& perm,
SparseCsrTensor* dx) {
std::vector<int> grad_perm = get_gpu_grad_perm(perm);
TransposeCsrKernel<T, Context>(dev_ctx, dout, grad_perm, dx);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(transpose_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::TransposeCooGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(transpose_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi {
namespace sparse {
__global__ void TransposeCooCudaKernel(const int64_t *x_indices_data,
const int *perm,
const std::size_t n_dim,
const int64_t x_nnz,
int64_t *out_indices_data) {
CUDA_KERNEL_LOOP_TYPE(index, x_nnz * n_dim, int64_t) {
int64_t i = index / x_nnz;
int64_t j = index % x_nnz;
out_indices_data[index] = x_indices_data[j + perm[i] * x_nnz];
}
}
template <typename T>
__global__ void TransposeCsr2DCudaKernel(const int64_t *x_crows_data,
const int64_t *x_cols_data,
const T *x_values_data,
const int *perm,
const int64_t *x_dims,
const int64_t *out_dims,
const int64_t x_nnz,
int64_t *out_crows_data,
int64_t *out_cols_data,
T *out_values_data) {
int64_t __index__ =
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
// compute out_crows_data by x_cols_data
for (int64_t i = __index__; i <= out_dims[0]; i += blockDim.x * gridDim.x) {
out_crows_data[i] = 0;
}
__syncthreads();
if (__index__ == 0) {
for (int64_t i = 0; i < x_nnz; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 2]++;
}
for (int64_t i = 0; i < out_dims[0]; i += 1) {
out_crows_data[i + 1] += out_crows_data[i];
}
// compute out_cols_data and out_values_data by out_crows_data and x
for (int i = 0; i < x_dims[0]; ++i) {
int64_t start = x_crows_data[i];
int64_t end = x_crows_data[i + 1];
for (int64_t j = start; j < end; ++j) {
int64_t x_cols_j = x_cols_data[j] + 1;
int64_t jjj = out_crows_data[x_cols_j];
out_cols_data[jjj] = i;
out_values_data[jjj] = x_values_data[j];
out_crows_data[x_cols_j]++;
}
}
}
}
template <typename T>
__global__ void TransposeCsr3DCudaKernel(const int64_t *x_crows_data,
const int64_t *x_cols_data,
const T *x_values_data,
const int *perm,
const int64_t *x_dims,
const int64_t *out_dims,
const std::size_t n_dim,
const int64_t x_nnz,
int64_t *out_crows_data,
int64_t *out_cols_data,
T *out_values_data) {
int64_t __index__ =
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (__index__ == 0) {
int out_n_rows = out_dims[1];
int x_n_rows = x_dims[1];
for (int k = 0; k < out_dims[0]; ++k) {
if (perm[0] == 0) { // dims == {0, 2, 1}
// compute out_crows_data by x_cols_data
for (int i = 0; i <= out_n_rows; ++i) {
out_crows_data[i] = 0;
}
for (int i = 0; i < x_crows_data[x_n_rows]; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 2]++;
}
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i + 1] += out_crows_data[i];
}
// compute out_cols_data and out_values_data by out_crows_data and x
for (int i = 0; i < x_n_rows; ++i) {
int64_t start = x_crows_data[i];
int64_t end = x_crows_data[i + 1];
for (int64_t j = start; j < end; ++j) {
int64_t x_cols_j = x_cols_data[j] + 1;
int64_t jjj = out_crows_data[x_cols_j];
out_cols_data[jjj] = i;
out_values_data[jjj] = x_values_data[j];
out_crows_data[x_cols_j]++;
}
}
// x offset
x_cols_data += x_crows_data[x_n_rows];
x_values_data += x_crows_data[x_n_rows];
x_crows_data += x_n_rows + 1;
} else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2}
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
int x_cols_offset = 0;
int out_cols_index = 0;
for (int i = 0; i < x_dims[0]; ++i) {
int x_crows_index = i * (x_n_rows + 1);
int start = x_crows_data[x_crows_index + k];
int end = x_crows_data[x_crows_index + 1 + k];
out_crows_data[i + 1] = end - start;
for (int j = start; j < end; ++j) {
out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j];
out_values_data[out_cols_index] = x_values_data[x_cols_offset + j];
out_cols_index++;
}
x_cols_offset += x_crows_data[x_crows_index + x_n_rows];
}
for (int i = 1; i <= out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
}
// out offset
out_cols_data += out_crows_data[out_n_rows];
out_values_data += out_crows_data[out_n_rows];
out_crows_data += out_n_rows + 1;
}
}
}
template <typename T, typename Context>
void TransposeCooKernel(const Context &dev_ctx,
const SparseCooTensor &x,
const std::vector<int> &perm,
SparseCooTensor *out) {
// create out sparse tensor
int64_t x_nnz = x.nnz();
std::size_t n_dim = perm.size();
DDim out_dims = x.dims().transpose(perm);
DenseTensor out_indices = EmptyLike<int64_t, Context>(dev_ctx, x.indices());
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of indices
const DenseTensor &x_indices = x.indices();
const auto *x_indices_data = x_indices.data<int64_t>();
auto *out_indices_data = out_indices.data<int64_t>();
int *d_perm;
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
hipMemcpy(
d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
cudaMemcpy(
d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice);
#endif
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz * n_dim, 1);
TransposeCooCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
x_indices_data, d_perm, n_dim, x_nnz, out_indices_data);
}
template <typename T, typename Context>
void TransposeCsrKernel(const Context &dev_ctx,
const SparseCsrTensor &x,
const std::vector<int> &perm,
SparseCsrTensor *out) {
std::size_t n_dim = perm.size();
const DenseTensor &x_crows = x.crows();
const DenseTensor &x_cols = x.cols();
const DenseTensor &x_values = x.non_zero_elements();
DenseTensor out_crows, out_cols, out_values;
// return a copy of x
if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) {
out_crows = x_crows;
out_cols = x_cols;
out_values = x_values;
out->SetMember(out_crows, out_cols, out_values, x.dims());
return;
}
// create out sparse tensor
DDim out_dims = x.dims().transpose(perm);
if (n_dim == 2) {
out_crows = Empty<int64_t, Context>(dev_ctx, {out_dims[0] + 1});
} else {
out_crows =
Empty<int64_t, Context>(dev_ctx, {out_dims[0] * (out_dims[1] + 1)});
}
out_cols = EmptyLike<int64_t, Context>(dev_ctx, x.cols());
out_values = EmptyLike<T, Context>(dev_ctx, x.values());
out->SetMember(out_crows, out_cols, out_values, out_dims);
// transpose by two stages
if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {0, 2, 1}, out);
return;
} else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {1, 0, 2}, out);
return;
} else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {2, 0, 1}, out);
return;
}
int64_t *out_crows_data = out_crows.data<int64_t>();
int64_t *out_cols_data = out_cols.data<int64_t>();
T *out_values_data = out_values.data<T>();
const int64_t *x_crows_data = x_crows.data<int64_t>();
const int64_t *x_cols_data = x_cols.data<int64_t>();
const T *x_values_data = x_values.data<T>();
int *d_perm;
int64_t *d_x_dims, *d_out_dims;
#ifdef PADDLE_WITH_HIP
hipMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
hipMemcpy(
d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
hipMalloc(reinterpret_cast<void **>(&d_x_dims),
sizeof(int64_t) * x.dims().size());
hipMemcpy(d_x_dims,
x.dims().Get(),
sizeof(int64_t) * x.dims().size(),
hipMemcpyHostToDevice);
hipMalloc(reinterpret_cast<void **>(&d_out_dims),
sizeof(int64_t) * out_dims.size());
hipMemcpy(d_out_dims,
out_dims.Get(),
sizeof(int64_t) * out_dims.size(),
hipMemcpyHostToDevice);
#else
cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
cudaMemcpy(
d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice);
cudaMalloc(reinterpret_cast<void **>(&d_x_dims),
sizeof(int64_t) * x.dims().size());
cudaMemcpy(d_x_dims,
x.dims().Get(),
sizeof(int64_t) * x.dims().size(),
cudaMemcpyHostToDevice);
cudaMalloc(reinterpret_cast<void **>(&d_out_dims),
sizeof(int64_t) * out_dims.size());
cudaMemcpy(d_out_dims,
out_dims.Get(),
sizeof(int64_t) * out_dims.size(),
cudaMemcpyHostToDevice);
#endif
int64_t x_nnz = x.nnz();
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_dims[0], 1);
if (perm.size() == 2) {
TransposeCsr2DCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x_crows_data,
x_cols_data,
x_values_data,
d_perm,
d_x_dims,
d_out_dims,
x_nnz,
out_crows_data,
out_cols_data,
out_values_data);
} else {
TransposeCsr3DCudaKernel<T><<<1, 1, 0, dev_ctx.stream()>>>(x_crows_data,
x_cols_data,
x_values_data,
d_perm,
d_x_dims,
d_out_dims,
perm.size(),
x_nnz,
out_crows_data,
out_cols_data,
out_values_data);
}
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(transpose_coo,
GPU,
ALL_LAYOUT,
phi::sparse::TransposeCooKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(transpose_csr,
GPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
......@@ -77,5 +77,29 @@ void CastCsrGradKernel(const Context& dev_ctx,
DataType value_dtype,
SparseCsrTensor* dx);
template <typename T, typename Context>
void TransposeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout,
const std::vector<int>& perm,
SparseCooTensor* dx);
template <typename T, typename Context>
void TransposeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
const std::vector<int>& perm,
SparseCsrTensor* dx);
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx);
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx);
} // namespace sparse
} // namespace phi
......@@ -14,6 +14,8 @@
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
......@@ -99,6 +101,48 @@ void CastCsrKernel(const Context& dev_ctx,
DataType value_dtype,
SparseCsrTensor* out);
template <typename T, typename Context>
void TransposeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int>& perm,
SparseCooTensor* out);
template <typename T, typename Context>
void TransposeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int>& perm,
SparseCsrTensor* out);
template <typename T, typename Context>
SparseCooTensor TransposeCoo(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int>& perm) {
PADDLE_ENFORCE_EQ(x.sparse_dim(),
perm.size(),
phi::errors::InvalidArgument(
"size of perm must be equal than the x.sparse_dim()"));
SparseCooTensor coo;
TransposeCooKernel<T, Context>(dev_ctx, x, perm, &coo);
return coo;
}
template <typename T, typename Context>
SparseCsrTensor TransposeCsr(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int>& perm) {
PADDLE_ENFORCE_LE(
2,
perm.size(),
phi::errors::InvalidArgument("size of perm must be equal to 2 or 3"));
PADDLE_ENFORCE_GE(
3,
perm.size(),
phi::errors::InvalidArgument("size of perm must be equal to 2 or 3"));
SparseCsrTensor csr;
TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &csr);
return csr;
}
template <typename T, typename Context>
SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) {
SparseCooTensor coo;
......@@ -113,5 +157,43 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
return csr;
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out);
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out);
template <typename T, typename Context>
SparseCooTensor ReshapeCoo(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape) {
SparseCooTensor coo;
ReshapeCooKernel<T, Context>(dev_ctx, x, shape, &coo);
return coo;
}
template <typename T, typename Context>
SparseCsrTensor ReshapeCsr(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape) {
PADDLE_ENFORCE_LE(
2,
shape.size(),
phi::errors::InvalidArgument("size of shape must be equal to 2 or 3"));
PADDLE_ENFORCE_GE(
3,
shape.size(),
phi::errors::InvalidArgument("size of shape must be equal to 2 or 3"));
SparseCsrTensor csr;
ReshapeCsrKernel<T, Context>(dev_ctx, x, shape, &csr);
return csr;
}
} // namespace sparse
} // namespace phi
......@@ -74,6 +74,10 @@ cc_test(
test_sparse_elementwise_dev_api
SRCS test_sparse_elementwise_dev_api.cc
DEPS phi phi_api_utils)
cc_test(
test_sparse_transpose_dev_api
SRCS test_sparse_transpose_dev_api.cc
DEPS phi phi_api_utils)
cc_test(
test_math_function
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/kernels/transpose_grad_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
namespace tests {
TEST(DEV_API, sparse_transpose_coo) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_coo = sparse::DenseToCoo<float>(dev_ctx_cpu, dense_x, 3);
auto sparse_out =
sparse::TransposeCoo<float>(dev_ctx_cpu, sparse_coo, {2, 1, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out);
// backward
DenseTensor dense_grad_x = phi::EmptyLike<float>(dev_ctx_cpu, dense_out);
TransposeGradKernel<float>(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x);
SparseCooTensor sparse_grad_x;
sparse::EmptyLikeCooKernel<float>(dev_ctx_cpu, sparse_coo, &sparse_grad_x);
SparseCooTensor sparse_out_grad(
sparse_coo.indices(), sparse_coo.values(), {2, 2, 3});
sparse::TransposeCooGradKernel<float>(
dev_ctx_cpu, sparse_out_grad, {2, 1, 0}, &sparse_grad_x);
}
TEST(DEV_API, sparse_transpose_csr_case1) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_csr = sparse::DenseToCsr<float>(dev_ctx_cpu, dense_x);
auto sparse_out =
sparse::TransposeCsr<float>(dev_ctx_cpu, sparse_csr, {2, 1, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out);
// backward
DenseTensor dense_grad_x = phi::EmptyLike<float>(dev_ctx_cpu, dense_out);
TransposeGradKernel<float>(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x);
SparseCsrTensor sparse_grad_x;
sparse::EmptyLikeCsrKernel<float>(dev_ctx_cpu, sparse_csr, &sparse_grad_x);
sparse::TransposeCsrGradKernel<float>(
dev_ctx_cpu, sparse_out, {2, 1, 0}, &sparse_grad_x);
}
TEST(DEV_API, sparse_transpose_csr_case2) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_csr = sparse::DenseToCsr<float>(dev_ctx_cpu, dense_x);
auto sparse_out =
sparse::TransposeCsr<float>(dev_ctx_cpu, sparse_csr, {1, 2, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {1, 2, 0}, &dense_out);
}
TEST(DEV_API, sparse_transpose_csr_case3) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 4}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_csr = sparse::DenseToCsr<float>(dev_ctx_cpu, dense_x);
auto sparse_out =
sparse::TransposeCsr<float>(dev_ctx_cpu, sparse_csr, {1, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({4, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {1, 0}, &dense_out);
}
} // namespace tests
} // namespace phi
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.incubate.sparse.binary import is_same_shape
class TestSparseIsSameShapeAPI(unittest.TestCase):
"""
test paddle.incubate.sparse.is_same_shape
"""
def setUp(self):
self.shapes = [[2, 5, 8], [3, 4]]
self.tensors = [
paddle.rand(self.shapes[0]),
paddle.rand(self.shapes[0]),
paddle.rand(self.shapes[1])
]
self.sparse_dim = 2
def test_dense_dense(self):
self.assertTrue(is_same_shape(self.tensors[0], self.tensors[1]))
self.assertFalse(is_same_shape(self.tensors[0], self.tensors[2]))
self.assertFalse(is_same_shape(self.tensors[1], self.tensors[2]))
def test_dense_csr(self):
self.assertTrue(
is_same_shape(self.tensors[0], self.tensors[1].to_sparse_csr()))
self.assertFalse(
is_same_shape(self.tensors[0], self.tensors[2].to_sparse_csr()))
self.assertFalse(
is_same_shape(self.tensors[1], self.tensors[2].to_sparse_csr()))
def test_dense_coo(self):
self.assertTrue(
is_same_shape(self.tensors[0],
self.tensors[1].to_sparse_coo(self.sparse_dim)))
self.assertFalse(
is_same_shape(self.tensors[0],
self.tensors[2].to_sparse_coo(self.sparse_dim)))
self.assertFalse(
is_same_shape(self.tensors[1],
self.tensors[2].to_sparse_coo(self.sparse_dim)))
def test_csr_dense(self):
self.assertTrue(
is_same_shape(self.tensors[0].to_sparse_csr(), self.tensors[1]))
self.assertFalse(
is_same_shape(self.tensors[0].to_sparse_csr(), self.tensors[2]))
self.assertFalse(
is_same_shape(self.tensors[1].to_sparse_csr(), self.tensors[2]))
def test_csr_csr(self):
self.assertTrue(
is_same_shape(self.tensors[0].to_sparse_csr(),
self.tensors[1].to_sparse_csr()))
self.assertFalse(
is_same_shape(self.tensors[0].to_sparse_csr(),
self.tensors[2].to_sparse_csr()))
self.assertFalse(
is_same_shape(self.tensors[1].to_sparse_csr(),
self.tensors[2].to_sparse_csr()))
def test_csr_coo(self):
self.assertTrue(
is_same_shape(self.tensors[0].to_sparse_csr(),
self.tensors[1].to_sparse_coo(self.sparse_dim)))
self.assertFalse(
is_same_shape(self.tensors[0].to_sparse_csr(),
self.tensors[2].to_sparse_coo(self.sparse_dim)))
self.assertFalse(
is_same_shape(self.tensors[1].to_sparse_csr(),
self.tensors[2].to_sparse_coo(self.sparse_dim)))
def test_coo_dense(self):
self.assertTrue(
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
self.tensors[1]))
self.assertFalse(
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
self.tensors[2]))
self.assertFalse(
is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim),
self.tensors[2]))
def test_coo_csr(self):
self.assertTrue(
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
self.tensors[1].to_sparse_csr()))
self.assertFalse(
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
self.tensors[2].to_sparse_csr()))
self.assertFalse(
is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim),
self.tensors[2].to_sparse_csr()))
def test_coo_coo(self):
self.assertTrue(
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
self.tensors[1].to_sparse_coo(self.sparse_dim)))
self.assertFalse(
is_same_shape(self.tensors[0].to_sparse_coo(self.sparse_dim),
self.tensors[2].to_sparse_coo(self.sparse_dim)))
self.assertFalse(
is_same_shape(self.tensors[1].to_sparse_coo(self.sparse_dim),
self.tensors[2].to_sparse_coo(self.sparse_dim)))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
class TestReshape(unittest.TestCase):
"""
Test the API paddle.incubate.sparse.reshape on some sparse tensors.
x: sparse, out: sparse
"""
def check_result(self, x_shape, new_shape, format):
"""
x_shape: original shape
new_shape: new shape
format: "coo" or "csr"
Transform a sparse tensor with shape "x_shape" to
a sparse tensor with shape "new_shape".
Compare the output of paddle.reshape and the output of
paddle.incubate.sparse.reshape.
"""
mask = np.random.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
# check cpu kernel
dense_x = paddle.to_tensor(np_x, place=paddle.CPUPlace())
dense_x.stop_gradient = False
dense_out = paddle.reshape(dense_x, new_shape)
if format == "coo":
sp_x = paddle.to_tensor(np_x,
place=paddle.CPUPlace()).to_sparse_coo(
len(x_shape))
else:
sp_x = paddle.to_tensor(np_x,
place=paddle.CPUPlace()).to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
np.testing.assert_allclose(sp_out.to_dense().numpy(),
dense_out.numpy(),
rtol=1e-05)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy() *
np_x.astype('bool').astype('int'),
rtol=1e-05)
# check gpu kernel
if paddle.device.is_compiled_with_cuda():
dense_x = paddle.to_tensor(np_x, place=paddle.CUDAPlace(0))
dense_x.stop_gradient = False
dense_out = paddle.reshape(dense_x, new_shape)
if format == "coo":
sp_x = paddle.to_tensor(
np_x, place=paddle.CUDAPlace(0)).to_sparse_coo(len(x_shape))
else:
sp_x = paddle.to_tensor(
np_x, place=paddle.CUDAPlace(0)).to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
np.testing.assert_allclose(sp_out.to_dense().numpy(),
dense_out.numpy(),
rtol=1e-05)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy() *
np_x.astype('bool').astype('int'),
rtol=1e-05)
def test_reshape_2d(self):
self.check_result([2, 5], [
10,
], 'coo')
self.check_result([12, 5], [15, 4], 'coo')
self.check_result([10, 5], [2, 25], 'csr')
self.check_result([9, 8], [18, 4], 'csr')
def test_reshape_3d(self):
self.check_result([6, 2, 3], [6, 2, 3], 'coo')
self.check_result([6, 2, 3], [2, 3, 3, 2], 'coo')
self.check_result([6, 2, 3], [1, 18, 2], 'coo')
self.check_result([6, 2, 3], [2, 9, 2], 'coo')
self.check_result([6, 2, 3], [2, 1, 18], 'coo')
self.check_result([6, 2, 3], [1, 2, 2, 3, 3], 'coo')
self.check_result([6, 2, 3], [6, 2, 3], 'csr')
self.check_result([6, 2, 3], [6, 3, 2], 'csr')
self.check_result([6, 2, 3], [2, 6, 3], 'csr')
self.check_result([6, 2, 3], [3, 6, 2], 'csr')
self.check_result([6, 2, 3], [4, 9, 1], 'csr')
self.check_result([6, 2, 3], [12, 1, 3], 'csr')
def test_reshape_nd(self):
self.check_result([8, 3, 4, 4, 5, 3], [24, 8, 10, 3], 'coo')
self.check_result([3, 4, 4, 5, 7], [1, 12, 2, 5, 14], 'coo')
def test_reshape_with_zero_or_minus_one_in_new_shape(self):
self.check_result([6, 2, 3], [-1, 0, 3], 'coo')
self.check_result([6, 2, 3], [2, 3, 0, -1], 'coo')
self.check_result([6, 2, 3], [1, -1, 2], 'coo')
self.check_result([6, 2, 3], [-1, 9, 2], 'coo')
self.check_result([6, 2, 3], [2, -1, 18], 'coo')
self.check_result([6, 2, 3], [1, 0, 2, -1, 3], 'coo')
self.check_result([6, 2, 3], [0, 0, -1], 'csr')
self.check_result([6, 2, 3], [-1, 3, 2], 'csr')
self.check_result([6, 2, 3], [2, -1, 0], 'csr')
self.check_result([6, 2, 3], [-1, 6, 2], 'csr')
self.check_result([6, 2, 3], [-1, 9, 1], 'csr')
self.check_result([6, 2, 3], [-1, 1, 3], 'csr')
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from paddle.fluid.framework import _test_eager_guard
class TestTranspose(unittest.TestCase):
# x: sparse, out: sparse
def check_result(self, x_shape, dims, format):
with _test_eager_guard():
mask = paddle.randint(0, 2, x_shape).astype("float32")
# "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask",
# or the backward checks may fail.
origin_x = (paddle.rand(x_shape, dtype='float32') + 1) * mask
dense_x = origin_x.detach()
dense_x.stop_gradient = False
dense_out = paddle.transpose(dense_x, dims)
if format == "coo":
sp_x = origin_x.detach().to_sparse_coo(len(x_shape))
else:
sp_x = origin_x.detach().to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.incubate.sparse.transpose(sp_x, dims)
np.testing.assert_allclose(sp_out.to_dense().numpy(),
dense_out.numpy(),
rtol=1e-05)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
(dense_x.grad * mask).numpy(),
rtol=1e-05)
def test_transpose_2d(self):
self.check_result([2, 5], [0, 1], 'coo')
self.check_result([2, 5], [0, 1], 'csr')
self.check_result([2, 5], [1, 0], 'coo')
self.check_result([2, 5], [1, 0], 'csr')
def test_transpose_3d(self):
self.check_result([6, 2, 3], [0, 1, 2], 'coo')
self.check_result([6, 2, 3], [0, 1, 2], 'csr')
self.check_result([6, 2, 3], [0, 2, 1], 'coo')
self.check_result([6, 2, 3], [0, 2, 1], 'csr')
self.check_result([6, 2, 3], [1, 0, 2], 'coo')
self.check_result([6, 2, 3], [1, 0, 2], 'csr')
self.check_result([6, 2, 3], [2, 0, 1], 'coo')
self.check_result([6, 2, 3], [2, 0, 1], 'csr')
self.check_result([6, 2, 3], [2, 1, 0], 'coo')
self.check_result([6, 2, 3], [2, 1, 0], 'csr')
self.check_result([6, 2, 3], [1, 2, 0], 'coo')
self.check_result([6, 2, 3], [1, 2, 0], 'csr')
def test_transpose_nd(self):
self.check_result([8, 3, 4, 4, 5, 3], [5, 3, 4, 1, 0, 2], 'coo')
# Randint now only supports access to dimension 0 to 9.
self.check_result([2, 3, 4, 2, 3, 4, 2, 3, 4],
[2, 3, 4, 5, 6, 7, 8, 0, 1], 'coo')
if __name__ == "__main__":
unittest.main()
......@@ -34,6 +34,8 @@ from .unary import coalesce
from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1
from .unary import transpose
from .unary import reshape
from .binary import mv
from .binary import matmul
......@@ -42,39 +44,16 @@ from .binary import add
from .binary import divide
from .binary import multiply
from .binary import subtract
from .binary import is_same_shape
from .multiary import addmm
from . import nn
__all__ = [
'sparse_coo_tensor',
'sparse_csr_tensor',
'sin',
'tan',
'asin',
'atan',
'sinh',
'tanh',
'asinh',
'atanh',
'sqrt',
'square',
'log1p',
'abs',
'pow',
'cast',
'neg',
'deg2rad',
'rad2deg',
'expm1',
'mv',
'matmul',
'masked_matmul',
'addmm',
'add',
'subtract',
'multiply',
'divide',
'coalesce',
'sparse_coo_tensor', 'sparse_csr_tensor', 'sin', 'tan', 'asin', 'atan',
'sinh', 'tanh', 'asinh', 'atanh', 'sqrt', 'square', 'log1p', 'abs', 'pow',
'cast', 'neg', 'deg2rad', 'rad2deg', 'expm1', 'mv', 'matmul',
'masked_matmul', 'addmm', 'add', 'subtract', 'transpose', 'multiply',
'divide', 'coalesce', 'is_same_shape', 'reshape'
]
......@@ -414,3 +414,36 @@ def divide(x, y, name=None):
if y.dtype != x.dtype:
y = _C_ops.sparse_cast(y, None, x.dtype)
return _C_ops.sparse_divide(x, y)
@dygraph_only
def is_same_shape(x, y):
"""
Return the results of shape comparison between two Tensors, check whether x.shape equal to y.shape.
Any two type Tensor among DenseTensor/SparseCooTensor/SparseCsrTensor are supported.
Args:
x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
Returns:
bool: True for same shape and False for different shape.
Examples:
.. code-block:: python
import paddle
x = paddle.rand([2, 3, 8])
y = paddle.rand([2, 3, 8])
y = y.to_sparse_csr()
z = paddle.rand([2, 5])
paddle.incubate.sparse.is_same_shape(x, y)
# True
paddle.incubate.sparse.is_same_shape(x, z)
# False
"""
return x.is_same_shape(y)
......@@ -119,6 +119,37 @@ def asin(x, name=None):
return _C_ops.sparse_asin(x)
@dygraph_only
def transpose(x, perm, name=None):
"""
Changes the perm order of ``x`` without changing its data, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
out = transpose(x, perm)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
perm (list|tuple): Permute the input according to the data of perm.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A transposed Sparse Tensor with the same data type as ``x``.
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([[-2., 0.], [1., 2.]])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.transpose(sparse_x, [1, 0])
"""
return _C_ops.sparse_transpose(x, perm)
@dygraph_only
def atan(x, name=None):
"""
......@@ -608,3 +639,60 @@ def expm1(x, name=None):
out = paddle.incubate.sparse.expm1(sparse_x)
"""
return _C_ops.sparse_expm1(x)
@dygraph_only
def reshape(x, shape, name=None):
"""
Changes the shape of ``x`` without changing its value, requiring x to be a SparseCooTensor or SparseCsrTensor.
Currently this function can only reshape the sparse dims of ``x`` , but ``shape`` argument must be specified
as the shape of the reshaped tensor.
Note that if x is a SparseCsrTensor, then len(shape) must be 2 or 3.
There are some tricks when specifying the target shape.
- 1. -1 means the value of this dimension is inferred from the total element number of x and remaining dimensions. Thus one and only one dimension can be set -1.
- 2. 0 means the actual dimension value is going to be copied from the corresponding dimension of x. The indices of 0 in the target shape can not exceed the rank of x.
Here are some examples to explain it.
- 1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [6, 8], the reshape operator will transform x into a 2-D tensor with shape [6, 8] and leaving x's data unchanged.
- 2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [2, 3, -1, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this case, one dimension of the target shape is set to -1, the value of this dimension is inferred from the total element number of x and remaining dimensions.
- 3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case, besides -1, 0 means the actual dimension value is going to be copied from the corresponding dimension of x.
Args:
x (Tensor): The input sparse tensor with data type ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``.
shape (list|tuple): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A reshaped Tensor with the same data type as ``x``.
Examples:
.. code-block:: python
import paddle
x_shape = [6, 2, 3]
new_shape = [1, 0, 2, -1, 3]
format = "coo"
dense_x = paddle.randint(-100, 100, x_shape) * paddle.randint(0, 2, x_shape)
if format == "coo":
sp_x = dense_x.to_sparse_coo(len(x_shape))
else:
sp_x = dense_x.to_sparse_csr()
sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
print(sp_out)
# the shape of sp_out is [1, 2, 2, 3, 3]
"""
return _C_ops.sparse_reshape(x, shape)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册