Unverified commit d71baff6, authored by Scotty and committed by GitHub

[Hackathon 4th No.29] Add the paddle.sparse.slice sparse API to Paddle (#53794)

Parent 1f82bc37
...@@ -463,3 +463,14 @@
    func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
    layout : softmax
    data_type: query
- backward_op: slice_grad
forward : slice(Tensor x, IntArray axes, IntArray starts, IntArray ends) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray axes, IntArray starts, IntArray ends)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : slice_coo_grad{sparse_coo, sparse_coo -> sparse_coo},
slice_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
...@@ -526,3 +526,15 @@
           mv_csr{sparse_csr, dense -> dense}
    layout : x
  backward: mv_grad
- op: slice
args : (Tensor x, IntArray axes, IntArray starts, IntArray ends)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : slice_coo{sparse_coo -> sparse_coo},
slice_csr{sparse_csr -> sparse_csr}
layout : x
backward : slice_grad
...@@ -215,5 +215,62 @@ inline DDim GetDecreasedDims(const DDim slice_dims,
  return decreased_dims;
}
template <typename T = int64_t>
inline void CheckAndUpdateSparseSliceAttrs(const DDim in_dims,
std::vector<T>* axes,
std::vector<T>* starts,
std::vector<T>* ends) {
int64_t rank = int64_t(in_dims.size());
for (auto& axis : *axes) {
if (axis < 0) {
axis = std::max(int64_t(0), axis + rank);
}
}
PADDLE_ENFORCE_EQ(
axes->size(),
starts->size(),
phi::errors::InvalidArgument(
"The length of axes (%d) and length of starts (%d) should be same.",
axes->size(),
starts->size()));
PADDLE_ENFORCE_EQ(
axes->size(),
ends->size(),
phi::errors::InvalidArgument(
"The length of axes (%d) and length of ends (%d) should be same.",
axes->size(),
ends->size()));
CheckAndUpdateSliceAttrs<T>(in_dims, *axes, starts, ends);
}
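A minimal Python sketch (illustrative only, not the actual Paddle helper) of what CheckAndUpdateSparseSliceAttrs does: wrap negative axes by the rank, check that axes/starts/ends have matching lengths, then wrap and clamp each per-axis range the same way the dense slice attrs are normalized. Function and example values below are assumptions for illustration.

def normalize_sparse_slice_attrs(shape, axes, starts, ends):
    rank = len(shape)
    axes = [a + rank if a < 0 else a for a in axes]
    assert len(axes) == len(starts) == len(ends), "axes/starts/ends must match"
    norm_starts, norm_ends = [], []
    for a, st, ed in zip(axes, starts, ends):
        dim = shape[a]
        st = max(0, min(st + dim if st < 0 else st, dim))  # wrap then clamp
        ed = max(0, min(ed + dim if ed < 0 else ed, dim))
        norm_starts.append(st)
        norm_ends.append(max(st, ed))  # empty slice if end <= start
    return axes, norm_starts, norm_ends

# normalize_sparse_slice_attrs([3, 4], [-2, -1], [-3, 0], [2, -1])
# -> ([0, 1], [0, 0], [2, 3])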
inline void ConstructNewSliceAttrs(const phi::DDim& x_dims,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
std::vector<int64_t>* new_axes,
std::vector<int64_t>* new_starts,
std::vector<int64_t>* new_ends) {
for (int64_t i = 0; i < x_dims.size(); ++i) {
int pos = -1;
for (int j = 0; j < static_cast<int>(axes.size()); ++j) {
if (axes[j] == i) {
pos = j;
break;
}
}
if (pos == -1) {
(*new_axes)[i] = i;
(*new_starts)[i] = 0;
(*new_ends)[i] = x_dims[i];
} else {
(*new_axes)[i] = axes[pos];
(*new_starts)[i] = starts[pos];
(*new_ends)[i] = ends[pos];
}
}
}
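The CSR kernels below always work with an explicit slice range for every dimension, so ConstructNewSliceAttrs expands a partial (axes, starts, ends) triple into full-rank attributes; dimensions not mentioned in axes get the full range [0, dim). A hedged Python sketch of that expansion (names are illustrative):

def construct_full_slice_attrs(shape, axes, starts, ends):
    new_axes, new_starts, new_ends = [], [], []
    for i, dim in enumerate(shape):
        if i in axes:
            j = axes.index(i)
            new_axes.append(i)
            new_starts.append(starts[j])
            new_ends.append(ends[j])
        else:
            new_axes.append(i)      # untouched dim keeps its full range
            new_starts.append(0)
            new_ends.append(dim)
    return new_axes, new_starts, new_ends

# construct_full_slice_attrs([2, 3, 4], [0, 2], [1, 1], [2, 3])
# -> ([0, 1, 2], [1, 0, 1], [2, 3, 3])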
} // namespace funcs
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SliceCooGradCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* x_grad) {
// set x_grad
const int64_t out_grad_nnz = out_grad.nnz();
auto sparse_dim = static_cast<int64_t>(out_grad.sparse_dim());
DenseTensor dx_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_indices_data = dx_indices.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
const auto* out_grad_indices_data = out_grad.indices().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
for (int64_t j = 0; j < out_grad_nnz; ++j) {
// set indices
for (int64_t i = 0; i < sparse_dim; ++i) {
dx_indices_data[i * out_grad_nnz + j] =
out_grad_indices_data[i * out_grad_nnz + j];
}
for (size_t ii = 0; ii < axes.size(); ++ii) {
int64_t i = axes[ii];
dx_indices_data[i * out_grad_nnz + j] += starts[ii];
}
// set value
dx_values_data[j] = out_grad_values_data[j];
}
x_grad->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
}
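The COO gradient is a one-to-one mapping: every non-zero of out_grad corresponds to exactly one non-zero of x, so the gradient keeps out_grad's values and only shifts its coordinates back by the slice starts, while the gradient tensor retains x's dense shape but stores only out_grad's non-zeros. A small numpy sketch of that mapping, under the same [sparse_dim, nnz] indices layout (illustrative helper, not Paddle code):

import numpy as np

def slice_coo_grad_indices(out_grad_indices, axes, starts):
    # out_grad_indices: int64 array of shape [sparse_dim, nnz]
    dx_indices = out_grad_indices.copy()
    for ax, st in zip(axes, starts):
        dx_indices[ax] += st  # undo the shift applied in the forward slice
    return dx_indices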
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCooGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
template <typename T>
void GetCsrInputGradCrows(const int64_t* out_grad_crows_data,
const int64_t out_grad_n_rows,
const int64_t x_n_rows,
const int64_t rows_start,
int64_t* dx_crows_data,
const int64_t out_grad_crows_offset = 0,
const int64_t dx_crows_offset = 0) {
for (int64_t i = 0; i < x_n_rows + 1; ++i) {
int64_t idx = i + dx_crows_offset;
if (i < rows_start) {
dx_crows_data[idx] = 0;
} else if (i < rows_start + out_grad_n_rows + 1) {
int64_t out_grad_idx = out_grad_crows_offset + (i - rows_start);
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
} else {
int64_t out_grad_idx = out_grad_crows_offset + out_grad_n_rows;
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
}
}
}
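GetCsrInputGradCrows rebuilds the row-pointer array of the gradient: rows before the sliced row range have no non-zeros (crows entry 0), rows inside the range reuse out_grad's crows, and rows after the range stay at out_grad's total nnz. A rough numpy sketch of the single-block case (assumed helper, not Paddle code):

import numpy as np

def grad_crows(out_grad_crows, x_n_rows, rows_start):
    out_grad_n_rows = len(out_grad_crows) - 1
    dx_crows = np.zeros(x_n_rows + 1, dtype=np.int64)
    for i in range(x_n_rows + 1):
        if i < rows_start:
            dx_crows[i] = 0
        elif i < rows_start + out_grad_n_rows + 1:
            dx_crows[i] = out_grad_crows[i - rows_start]
        else:
            dx_crows[i] = out_grad_crows[out_grad_n_rows]
    return dx_crows

# grad_crows(np.array([0, 1, 3]), x_n_rows=4, rows_start=1) -> [0 0 1 3 3]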
template <typename T, typename Context>
void SliceCsrGrad2D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t out_grad_nnz = out_grad.nnz();
const int64_t n_rows = x.dims()[0];
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {n_rows + 1});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
// set cols
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_cols_data[i] = out_grad_cols_data[i] + starts[1];
}
// set values
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_values_data[i] = out_grad_values_data[i];
}
// set crows
const int64_t out_grad_n_rows = out_grad.dims()[0];
GetCsrInputGradCrows<T>(out_grad_crows_data,
out_grad_n_rows,
n_rows,
starts[0],
dx_crows_data,
0,
0);
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
}
template <typename T, typename Context>
void SliceCsrGrad3D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t dim0 = x.dims()[0], n_rows = x.dims()[1];
const int64_t out_grad_nnz = out_grad.nnz();
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {dim0 * (n_rows + 1)});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
// set cols
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_cols_data[i] = out_grad_cols_data[i] + starts[2];
}
// set values
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_values_data[i] = out_grad_values_data[i];
}
// set crows
int64_t out_grad_n_rows = out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
if (i < starts[0] || i >= ends[0]) {
for (int64_t j = 0; j < n_rows + 1; ++j) {
dx_crows_data[i * (n_rows + 1) + j] = 0;
}
} else {
int64_t out_grad_crows_offset = (i - starts[0]) * (out_grad_n_rows + 1);
int64_t dx_crows_offset = i * (n_rows + 1);
GetCsrInputGradCrows<T>(out_grad_crows_data,
out_grad_n_rows,
n_rows,
starts[1],
dx_crows_data,
out_grad_crows_offset,
dx_crows_offset);
}
}
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
}
template <typename T, typename Context>
void SliceCsrGradCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
// Construct new axes, starts, and ends
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);
const int64_t sparse_dim = x_dims.size();
if (sparse_dim == 2) {
SliceCsrGrad2D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else if (sparse_dim == 3) {
SliceCsrGrad3D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Slice grad for Sparse CSR Tensor only supports 2-D or 3-D, but got "
        "%d-D.",
        x_dims.size()));
}
}
template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(slice_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SliceCooCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* out) {
const phi::DDim& x_dims = x.dims();
// Step1: Infer output dims
auto out_dims = funcs::GetSliceDims<int64_t>(
x_dims, axes, starts, ends, nullptr, nullptr);
// Step2: Get out_nnz (the number of non-zero elements in output)
const int64_t x_nnz = x.nnz();
int64_t out_nnz = 0;
const auto* x_indices_data = x.indices().data<int64_t>();
for (int64_t j = 0; j < x_nnz; ++j) {
bool hit = true;
for (size_t ii = 0; ii < axes.size(); ++ii) {
auto item = x_indices_data[axes[ii] * x_nnz + j];
if (!(starts[ii] <= item && item < ends[ii])) {
hit = false;
break;
}
}
if (!hit) continue;
out_nnz++;
}
// Step3: Get the values and indices of output
auto sparse_dim = static_cast<int64_t>(x.sparse_dim());
DenseTensor out_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_nnz});
DenseTensor out_values = phi::Empty<T, Context>(dev_ctx, {out_nnz});
auto* out_indices_data = out_indices.data<int64_t>();
auto* out_values_data = out_values.data<T>();
const auto* x_values_data = x.values().data<T>();
int64_t index = 0;
for (int64_t j = 0; j < x_nnz && index < out_nnz; ++j) {
bool hit = true;
for (size_t ii = 0; ii < axes.size(); ++ii) {
auto item = x_indices_data[axes[ii] * x_nnz + j];
if (!(starts[ii] <= item && item < ends[ii])) {
hit = false;
break;
}
}
if (!hit) continue;
// set value
out_values_data[index] = x_values_data[j];
// set coordinate
for (int64_t i = 0; i < sparse_dim; ++i) {
out_indices_data[i * out_nnz + index] = x_indices_data[i * x_nnz + j];
}
for (size_t ii = 0; ii < axes.size(); ++ii) {
auto i = axes[ii];
out_indices_data[i * out_nnz + index] -= starts[ii];
}
index++;
}
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}
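Conceptually the COO slice is a filter-and-shift over the coordinate matrix: keep the non-zeros whose coordinates fall inside [starts, ends) on every sliced axis, then subtract the start offsets. A compact numpy sketch of the same logic (illustrative names; the kernel's first pass only counts out_nnz so it can allocate exact-size outputs):

import numpy as np

def slice_coo(indices, values, axes, starts, ends):
    # indices: [sparse_dim, nnz]; values: [nnz]
    keep = np.ones(indices.shape[1], dtype=bool)
    for ax, st, ed in zip(axes, starts, ends):
        keep &= (indices[ax] >= st) & (indices[ax] < ed)
    out_indices = indices[:, keep].copy()
    out_values = values[keep].copy()
    for ax, st in zip(axes, starts):
        out_indices[ax] -= st  # rebase coordinates to the sliced box
    return out_indices, out_values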
template <typename T, typename Context>
void SliceCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* out) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCooCompute<T, Context>(dev_ctx, x, axes_vec, starts_vec, ends_vec, out);
}
int64_t GetCsrNonZeroNumber(const SparseCsrTensor& x,
const int64_t x_crows_start,
const int64_t x_crows_end,
const int64_t min_col,
const int64_t max_col,
const int64_t x_cols_offset = 0) {
const auto* x_crows_data = x.crows().data<int64_t>();
const auto* x_cols_data = x.cols().data<int64_t>();
int64_t out_nnz = 0;
for (int64_t i = x_crows_start; i < x_crows_end; ++i) {
int64_t st = x_crows_data[i] + x_cols_offset;
int64_t ed = x_crows_data[i + 1] + x_cols_offset;
for (int64_t jj = st; jj < ed; ++jj) {
if (x_cols_data[jj] >= min_col && x_cols_data[jj] < max_col) {
out_nnz++;
}
}
}
return out_nnz;
}
template <typename T>
void GetCsrSubMatrix(const SparseCsrTensor& x,
const int64_t x_crows_start,
const int64_t x_crows_end,
const int64_t min_col,
const int64_t max_col,
DenseTensor* out_crows,
DenseTensor* out_cols,
DenseTensor* out_values,
const int64_t x_cols_offset = 0,
const int64_t out_crows_offset = 0,
const int64_t out_cols_offset = 0) {
const auto* x_crows_data = x.crows().data<int64_t>();
const auto* x_cols_data = x.cols().data<int64_t>();
const auto* x_values_data = x.values().data<T>();
auto* out_crows_data = out_crows->data<int64_t>();
auto* out_cols_data = out_cols->data<int64_t>();
auto* out_values_data = out_values->data<T>();
out_crows_data[out_crows_offset] = 0;
int64_t index = 0, out_n_rows = x_crows_end - x_crows_start;
for (int i = 0; i < out_n_rows; ++i) {
int64_t st = x_crows_data[x_crows_start + i] + x_cols_offset;
int64_t ed = x_crows_data[x_crows_start + i + 1] + x_cols_offset;
for (int64_t jj = st; jj < ed; ++jj) {
if (x_cols_data[jj] >= min_col && x_cols_data[jj] < max_col) {
out_cols_data[out_cols_offset + index] = x_cols_data[jj] - min_col;
out_values_data[out_cols_offset + index] = x_values_data[jj];
index++;
}
}
out_crows_data[out_crows_offset + i + 1] = index;
}
}
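GetCsrSubMatrix extracts one 2-D block: for each selected row it keeps the entries whose column index lies in [min_col, max_col), rebases the columns, and accumulates a fresh crows array (the extra offsets in the C++ version exist only so the 3-D path can write several blocks into one shared buffer). A hedged numpy sketch of the single-block case:

import numpy as np

def csr_submatrix(crows, cols, values, row_start, row_end, min_col, max_col):
    out_crows, out_cols, out_values = [0], [], []
    for r in range(row_start, row_end):
        for j in range(crows[r], crows[r + 1]):
            if min_col <= cols[j] < max_col:
                out_cols.append(cols[j] - min_col)   # rebase column index
                out_values.append(values[j])
        out_crows.append(len(out_cols))              # running nnz per row
    return (np.array(out_crows, dtype=np.int64),
            np.array(out_cols, dtype=np.int64),
            np.array(out_values))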
template <typename T, typename Context>
void SliceCsrTensor2D(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const phi::DDim& out_dims,
SparseCsrTensor* out) {
// Step1: Get nnz of out
int64_t out_nnz =
GetCsrNonZeroNumber(x, starts[0], ends[0], starts[1], ends[1], 0);
// Step2: Set out
int64_t out_n_rows = ends[0] - starts[0];
DenseTensor out_crows =
phi::Empty<int64_t, Context>(dev_ctx, {out_n_rows + 1});
DenseTensor out_cols = phi::Empty<int64_t, Context>(dev_ctx, {out_nnz});
DenseTensor out_values = phi::Empty<T, Context>(dev_ctx, {out_nnz});
GetCsrSubMatrix<T>(x,
starts[0],
ends[0],
starts[1],
ends[1],
&out_crows,
&out_cols,
&out_values,
0,
0,
0);
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
template <typename T, typename Context>
void SliceCsrTensor3D(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const phi::DDim& out_dims,
SparseCsrTensor* out) {
const auto* x_crows_data = x.crows().data<int64_t>();
// Step1: Get nnz of out
const int64_t x_dim0 = x.dims()[0], x_n_rows = x.dims()[1];
int64_t x_cols_offset = 0, out_nnz = 0;
  // all_nnzs stores the nnz of each sliced batch along dim 0, which will be used to set out.
std::vector<int64_t> all_nnzs(ends[0] - starts[0]);
for (int64_t i = 0; i < x_dim0; ++i) {
if (i >= starts[0] && i < ends[0]) { // slice dim 0
int64_t x_crows_st = i * (x_n_rows + 1) + starts[1];
int64_t x_crows_ed = i * (x_n_rows + 1) + ends[1];
int64_t nnz = GetCsrNonZeroNumber(
x, x_crows_st, x_crows_ed, starts[2], ends[2], x_cols_offset);
out_nnz += nnz;
all_nnzs[i - starts[0]] = nnz;
}
    // advance the offset to the start of the next batch in the cols/values buffers
x_cols_offset += x_crows_data[(i + 1) * (x_n_rows + 1) - 1];
}
// Step2: Set out
const int64_t out_dim0 = out_dims[0], out_n_rows = out_dims[1];
DenseTensor out_crows =
phi::Empty<int64_t, Context>(dev_ctx, {out_dim0 * (out_n_rows + 1)});
DenseTensor out_cols = phi::Empty<int64_t, Context>(dev_ctx, {out_nnz});
DenseTensor out_values = phi::Empty<T, Context>(dev_ctx, {out_nnz});
x_cols_offset = 0;
int64_t out_crows_offset = 0, out_cols_offset = 0;
for (int64_t i = 0; i < x_dim0; ++i) {
if (i >= starts[0] && i < ends[0]) { // slice dim 0
int64_t x_crows_start = i * (x_n_rows + 1) + starts[1];
int64_t x_crows_end = i * (x_n_rows + 1) + ends[1];
GetCsrSubMatrix<T>(x,
x_crows_start,
x_crows_end,
starts[2],
ends[2],
&out_crows,
&out_cols,
&out_values,
x_cols_offset,
out_crows_offset,
out_cols_offset);
out_crows_offset += (out_n_rows + 1);
out_cols_offset += all_nnzs[i - starts[0]];
}
x_cols_offset += x_crows_data[(i + 1) * (x_n_rows + 1) - 1];
}
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
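In the batched (3-D) CSR layout, crows holds dim0 consecutive blocks of (n_rows + 1) entries, while cols and values are concatenated across batches; the loop above therefore accumulates x_cols_offset from the last crows entry of each block to find where a batch starts inside cols/values. A small sketch of that bookkeeping (assumed helper for illustration):

import numpy as np

def batch_col_offsets(crows, dim0, n_rows):
    # crows: int64 array of length dim0 * (n_rows + 1)
    offsets = [0]
    for b in range(dim0):
        nnz_b = crows[(b + 1) * (n_rows + 1) - 1]  # nnz of batch b
        offsets.append(offsets[-1] + nnz_b)
    return offsets  # offsets[b] is where batch b begins in cols/values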
template <typename T, typename Context>
void SliceCsrCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* out) {
const phi::DDim& x_dims = x.dims();
// Step1: Infer output dims
auto out_dims = funcs::GetSliceDims<int64_t>(
x_dims, axes, starts, ends, nullptr, nullptr);
// Step2: Construct new axes, starts and ends.
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);
  // Step3: Slice the CSR tensor according to its rank
if (x_dims.size() == 2) {
SliceCsrTensor2D<T, Context>(
dev_ctx, x, new_axes, new_starts, new_ends, out_dims, out);
} else if (x_dims.size() == 3) {
SliceCsrTensor3D<T, Context>(
dev_ctx, x, new_axes, new_starts, new_ends, out_dims, out);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Slice for Sparse CSR Tensor only supports 2-D or 3-D, but got %d-D.",
        x_dims.size()));
}
}
template <typename T, typename Context>
void SliceCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* out) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrCompute<T, Context>(dev_ctx, x, axes_vec, starts_vec, ends_vec, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(slice_csr,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
namespace sparse {
template <typename T>
__global__ void GetCooInputGradCudaKernel(const int64_t* out_grad_indices_data,
const T* out_grad_values_data,
const int64_t* axes,
const int64_t* starts,
const int64_t axes_size,
const int64_t sparse_dim,
const int64_t out_grad_nnz,
int64_t* dx_indices_data,
T* dx_values_data) {
CUDA_KERNEL_LOOP_TYPE(j, out_grad_nnz, int64_t) {
// set indices
for (int64_t i = 0; i < sparse_dim; ++i) {
dx_indices_data[i * out_grad_nnz + j] =
out_grad_indices_data[i * out_grad_nnz + j];
}
for (size_t ii = 0; ii < axes_size; ++ii) {
int64_t i = axes[ii];
dx_indices_data[i * out_grad_nnz + j] += starts[ii];
}
// set value
dx_values_data[j] = out_grad_values_data[j];
}
}
template <typename T, typename Context>
void SliceCooGradCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
// copy axes to device
auto d_axes_tensor = memory_utils::Alloc(
dev_ctx.GetPlace(),
sizeof(int64_t) * axes.size(),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int64_t* d_axes = reinterpret_cast<int64_t*>(d_axes_tensor->ptr());
memory_utils::Copy(dev_ctx.GetPlace(),
d_axes,
phi::CPUPlace(),
axes.data(),
sizeof(int64_t) * axes.size(),
dev_ctx.stream());
// copy starts to device
auto d_starts_tensor = memory_utils::Alloc(
dev_ctx.GetPlace(),
sizeof(int64_t) * starts.size(),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
memory_utils::Copy(dev_ctx.GetPlace(),
d_starts,
phi::CPUPlace(),
starts.data(),
sizeof(int64_t) * starts.size(),
dev_ctx.stream());
// Step2: Set indices and values of x_grad
const int64_t out_grad_nnz = out_grad.nnz();
auto sparse_dim = static_cast<int64_t>(out_grad.sparse_dim());
DenseTensor dx_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_indices_data = dx_indices.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
const auto* out_grad_indices_data = out_grad.indices().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
x_grad->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_grad_nnz + 1, 1);
GetCooInputGradCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_indices_data,
out_grad_values_data,
d_axes,
d_starts,
axes.size(),
sparse_dim,
out_grad_nnz,
dx_indices_data,
dx_values_data);
}
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update sparse slice attrs
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCooGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
template <typename T>
__global__ void GetCsrInputColsValuesCudaKernel(
const int64_t* out_grad_cols_data,
const T* out_grad_values_data,
const int64_t out_grad_nnz,
const int64_t cols_start,
int64_t* dx_cols_data,
T* dx_values_data) {
CUDA_KERNEL_LOOP_TYPE(i, out_grad_nnz, int64_t) {
dx_cols_data[i] = out_grad_cols_data[i] + cols_start;
dx_values_data[i] = out_grad_values_data[i];
}
}
__global__ void GetCsrInputCrowsCudaKernel(
const int64_t* out_grad_crows_data,
const int64_t out_grad_n_rows,
const int64_t out_grad_nnz,
const int64_t x_n_rows,
const int64_t rows_start,
const int64_t rows_end,
int64_t* dx_crows_data,
const int64_t dx_crows_offset = 0,
const int64_t out_grad_crows_offset = 0) {
CUDA_KERNEL_LOOP_TYPE(i, x_n_rows + 1, int64_t) {
int64_t idx = i + dx_crows_offset;
if (i < rows_start) {
dx_crows_data[idx] = 0;
} else if (i < rows_start + out_grad_n_rows + 1) {
int64_t out_grad_idx = out_grad_crows_offset + (i - rows_start);
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
} else {
int64_t out_grad_idx = out_grad_crows_offset + out_grad_n_rows;
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
}
}
}
template <typename T, typename Context>
void SliceCsrGrad2D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t out_grad_nnz = out_grad.nnz();
const int64_t n_rows = x.dims()[0];
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {n_rows + 1});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
// set cols and values
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_grad_nnz + 1, 1);
GetCsrInputColsValuesCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_cols_data,
out_grad_values_data,
out_grad_nnz,
starts[1],
dx_cols_data,
dx_values_data);
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n_rows + 1, 1);
GetCsrInputCrowsCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_crows_data,
out_grad.dims()[0],
out_grad_nnz,
x.dims()[0],
starts[0],
ends[0],
dx_crows_data,
0,
0);
}
__global__ void GetCsrInputCrowsPart1CudaKernel(const int64_t n_rows,
const int64_t dim0_idx,
int64_t* dx_crows_data) {
CUDA_KERNEL_LOOP_TYPE(j, n_rows + 1, int64_t) {
dx_crows_data[dim0_idx * (n_rows + 1) + j] = 0;
}
}
template <typename T, typename Context>
void SliceCsrGrad3D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t dim0 = x.dims()[0], n_rows = x.dims()[1];
const int64_t out_grad_nnz = out_grad.nnz();
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {dim0 * (n_rows + 1)});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
// set cols and values
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_grad_nnz + 1, 1);
GetCsrInputColsValuesCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_cols_data,
out_grad_values_data,
out_grad_nnz,
starts[2],
dx_cols_data,
dx_values_data);
// set crows
int64_t out_grad_n_rows = out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
if (i < starts[0] || i >= ends[0]) {
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n_rows + 1, 1);
GetCsrInputCrowsPart1CudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
n_rows, i, dx_crows_data);
} else {
int64_t dx_crows_offset = i * (n_rows + 1);
int64_t out_grad_crows_offset = (i - starts[0]) * (out_grad_n_rows + 1);
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n_rows + 1, 1);
GetCsrInputCrowsCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_crows_data,
out_grad_n_rows,
out_grad_nnz,
n_rows,
starts[1],
ends[1],
dx_crows_data,
dx_crows_offset,
out_grad_crows_offset);
}
}
}
template <typename T, typename Context>
void SliceCsrGradCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
// construct new axes, starts, and ends
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);
const int64_t sparse_dim = x_dims.size();
if (sparse_dim == 2) {
SliceCsrGrad2D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else if (sparse_dim == 3) {
SliceCsrGrad3D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Slice grad for Sparse CSR Tensor only supports 2-D or 3-D, but got "
        "%d-D.",
        x_dims.size()));
}
}
template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SliceCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(slice_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SliceCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
This diff is collapsed.
...@@ -121,5 +121,22 @@ void ReshapeCsrGradKernel(const Context& dev_ctx,
                          const SparseCsrTensor& dout,
                          SparseCsrTensor* dx);
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad);
template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad);
} // namespace sparse
} // namespace phi
...@@ -225,5 +225,21 @@ SparseCsrTensor ReshapeCsr(const Context& dev_ctx,
  return csr;
}
template <typename T, typename Context>
void SliceCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* out);
template <typename T, typename Context>
void SliceCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* out);
} // namespace sparse
} // namespace phi
...@@ -38,6 +38,7 @@ from .unary import transpose
from .unary import sum
from .unary import reshape
from .unary import isnan
from .unary import slice
from .binary import mv
from .binary import matmul
...@@ -87,4 +88,5 @@ __all__ = [
    'is_same_shape',
    'reshape',
    'isnan',
'slice',
]
...@@ -836,3 +836,87 @@ def isnan(x, name=None):
        type=op_type, inputs={'x': x}, outputs={'out': out}, attrs={}
    )
    return out
def slice(x, axes, starts, ends, name=None):
"""
This operator produces a slice of the sparse tensor ``x`` along multiple axes.
``axes``, ``starts`` and ``ends`` specify, for each axis in ``axes``, the start
and end positions used to slice the input sparse tensor ``x``. If a negative
value such as :math:`-i` is passed to ``starts`` or ``ends``, it represents
the :math:`i`-th position counted from the end of that axis (0 is the initial
position).
If the value passed to ``starts`` or ``ends`` is greater than the size n of
the dimension, it is treated as n.
For slicing to the end of a dimension with unknown size, it is recommended to
pass in INT_MAX. The lengths of ``axes``, ``starts`` and ``ends`` must be equal.
Args:
x (Tensor): The input Tensor (``SparseCooTensor`` or ``SparseCsrTensor``); its data type should be ``float16``, ``float32``, ``float64``, ``int32`` or ``int64``.
axes (list|tuple|Tensor): The data type is ``int32``. If ``axes`` is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If ``axes`` is a Tensor, it should be a 1-D Tensor.
Axes that `starts` and `ends` apply to.
starts (list|tuple|Tensor): The data type is ``int32``. If ``starts`` is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If ``starts`` is a Tensor, it should be a 1-D Tensor.
It represents starting indices of corresponding axis in ``axes``.
ends (list|tuple|Tensor): The data type is ``int32``. If ``ends`` is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If ``ends`` is a Tensor, it should be a 1-D Tensor.
It represents ending indices of corresponding axis in ``axes``.
Returns:
A Sparse Tensor. The data type is the same as ``x``.
Examples:
.. code-block:: python
import paddle
import numpy as np
format = 'coo'
np_x = np.asarray([[4, 0, 7, 0], [0, 0, 5, 0], [-4, 2, 0, 0]])
dense_x = paddle.to_tensor(np_x)
if format == 'coo':
sp_x = dense_x.to_sparse_coo(len(np_x.shape))
else:
sp_x = dense_x.to_sparse_csr()
axes = [0, 1]
starts = [1, 0]
ends = [3, -2]
sp_out = paddle.sparse.slice(sp_x, axes, starts, ends)
# sp_out is x[1:3, 0:-2]
print(sp_out)
# Tensor(shape=[2, 2], dtype=paddle.int64, place=Place(cpu), stop_gradient=True,
# indices=[[1, 1],
# [0, 1]],
# values=[-4, 2])
"""
if in_dynamic_mode():
return _C_ops.sparse_slice(x, axes, starts, ends)
else:
attrs = {'axes': axes, 'starts': starts, 'ends': ends}
check_variable_and_dtype(
x,
'x',
[
'bool',
'float32',
'float64',
'int16',
'int32',
'int64',
],
'sparse_slice',
)
check_type(axes, 'axes', (list, tuple), 'sparse_slice')
check_type(starts, 'starts', (list, tuple), 'sparse_slice')
check_type(ends, 'ends', (list, tuple), 'sparse_slice')
op_type = 'sparse_slice'
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs
)
return out
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
data_5d = [
[[2, 3, 4, 5, 6], [0, 1, 2, 4], [0, 1, 2, -4], [3, 3, 4, -2]],
]
data_4d = [
[[2, 3, 4, 5], [0, 1, 2, 3], [0, 1, 2, -4], [3, 3, 4, -2]],
]
data_3d = [
[[4, 4, 5], [-3, -2, -1], [1, -3, 2], [3, 3, 4]],
[[4, 4, 5], [0, 1, 2], [0, 1, 2], [3, 3, 4]],
[[4, 4, 5], [-1], [0], [2]],
[[4, 4, 5], [0], [1], [2]],
[[4, 4, 5], [1], [2], [3]],
[[4, 4, 5], [1, 2], [2, 2], [3, 4]],
[[4, 4, 5], [0, 2], [2, 2], [3, 4]],
]
data_2d = [
[[3, 4], [0], [0], [2]],
[[3, 4], [1], [-3], [2]],
[[3, 4], [-2, -1], [-3, 0], [2, -1]],
[[78, 78], [0, -1], [32, 58], [-2, -1]],
]
devices = ['cpu']
if paddle.device.get_device() != "cpu":
devices.append(paddle.device.get_device())
class TestSparseSlice(unittest.TestCase):
"""
Test the API paddle.sparse.slice on some sparse tensors.
x: sparse, out: sparse
"""
def _check_result(self, np_x, axes, starts, ends, format='coo'):
for device in devices:
paddle.device.set_device(device)
self._check_result_with_place(np_x, axes, starts, ends, format)
def _check_result_with_place(self, np_x, axes, starts, ends, format='coo'):
x_shape = np_x.shape
dense_x = paddle.to_tensor(np_x)
dense_x.stop_gradient = False
dense_out = paddle.slice(dense_x, axes, starts, ends)
if format == 'coo':
sp_x = paddle.to_tensor(np_x).to_sparse_coo(len(x_shape))
else:
sp_x = paddle.to_tensor(np_x).to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.sparse.slice(sp_x, axes, starts, ends)
np.testing.assert_allclose(
sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-5
)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(
sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy() * np_x.astype('bool').astype('int'),
rtol=1e-5,
)
def check_result_with_shape(
self, x_shape, axes, starts, ends, format='coo'
):
mask = np.random.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
self._check_result(np_x, axes, starts, ends, format)
def check_result_with_list(self, x, axes, starts, ends, format='coo'):
np_x = np.array(x)
self._check_result(np_x, axes, starts, ends, format)
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')
def test_coo_2d(self):
x = [[1, 2, 3, 4], [0, 1, 2, 0]]
self.check_result_with_list(x, [0, 1], [0, 1], [2, 3], format='coo')
for item in data_2d:
self.check_result_with_shape(*item, format='coo')
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
def test_csr_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='csr')
def test_csr_3d_zero(self):
x = [[[0, 0, 1, 2], [0, 0, 0, 2]]]
self.check_result_with_list(x, [1, 2], [0, 0], [2, 2], format='csr')
def test_csr_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='csr')
def test_csr_2d_zero(self):
x = [[0, 0, 1, 2], [0, 0, 0, 1]]
self.check_result_with_list(x, [0, 1], [0, 0], [2, 2], format='csr')
class TestSparseCooSliceStatic(unittest.TestCase):
def _check_result_coo(self, np_x, axes, starts, ends):
for device in devices:
paddle.device.set_device(device)
self._check_result_coo_with_place(np_x, axes, starts, ends)
def _check_result_coo_with_place(self, np_x, axes, starts, ends):
x_shape = np_x.shape
dense_x = paddle.to_tensor(np_x)
dense_x.stop_gradient = False
dense_out = paddle.slice(dense_x, axes, starts, ends)
sp_x = paddle.to_tensor(
np_x,
).to_sparse_coo(len(x_shape))
indices_data = sp_x.detach().indices()
values_data = sp_x.detach().values()
paddle.enable_static()
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
indices = paddle.static.data(
name='indices',
shape=indices_data.shape,
dtype=indices_data.dtype,
)
values = paddle.static.data(
name='values',
shape=values_data.shape,
dtype=values_data.dtype,
)
sp_x = paddle.sparse.sparse_coo_tensor(
indices,
values,
shape=dense_x.shape,
dtype=dense_x.dtype,
)
sp_out = paddle.sparse.slice(sp_x, axes, starts, ends)
sp_dense_out = sp_out.to_dense()
exe = paddle.static.Executor()
res = exe.run(
feed={
'indices': indices_data.numpy(),
'values': values_data.numpy(),
},
fetch_list=[sp_dense_out],
return_numpy=True,
)
np.testing.assert_allclose(
dense_out.numpy(),
res[0],
rtol=1e-5,
)
paddle.disable_static()
def check_result_with_shape(
self, x_shape, axes, starts, ends, format='coo'
):
mask = np.random.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)
def check_result_with_list(self, x, axes, starts, ends, format='coo'):
np_x = np.array(x)
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')
def test_coo_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='coo')
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
if __name__ == "__main__":
unittest.main()