Unverified commit d71baff6, authored by Scotty, committed by GitHub

【Hackathon 4th No.29】Add the paddle.sparse.slice sparse API to Paddle (#53794)

Parent 1f82bc37
......@@ -463,3 +463,14 @@
func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
layout : softmax
data_type: query
- backward_op: slice_grad
forward : slice(Tensor x, IntArray axes, IntArray starts, IntArray ends) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray axes, IntArray starts, IntArray ends)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : slice_coo_grad{sparse_coo, sparse_coo -> sparse_coo},
slice_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
......@@ -526,3 +526,15 @@
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad
- op: slice
args : (Tensor x, IntArray axes, IntArray starts, IntArray ends)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : slice_coo{sparse_coo -> sparse_coo},
slice_csr{sparse_csr -> sparse_csr}
layout : x
backward : slice_grad
......@@ -215,5 +215,62 @@ inline DDim GetDecreasedDims(const DDim slice_dims,
return decreased_dims;
}
template <typename T = int64_t>
inline void CheckAndUpdateSparseSliceAttrs(const DDim in_dims,
std::vector<T>* axes,
std::vector<T>* starts,
std::vector<T>* ends) {
  int64_t rank = static_cast<int64_t>(in_dims.size());
for (auto& axis : *axes) {
if (axis < 0) {
axis = std::max(int64_t(0), axis + rank);
}
}
PADDLE_ENFORCE_EQ(
axes->size(),
starts->size(),
phi::errors::InvalidArgument(
"The length of axes (%d) and length of starts (%d) should be same.",
axes->size(),
starts->size()));
PADDLE_ENFORCE_EQ(
axes->size(),
ends->size(),
phi::errors::InvalidArgument(
"The length of axes (%d) and length of ends (%d) should be same.",
axes->size(),
ends->size()));
CheckAndUpdateSliceAttrs<T>(in_dims, *axes, starts, ends);
}
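
For intuition, a minimal Python sketch of the attribute normalization above, assuming the usual clamping semantics of CheckAndUpdateSliceAttrs (the helper name is hypothetical, not part of this PR):

# Normalize negative axes, check lengths, then clamp starts/ends into
# [0, dim], with negative values counted from the end of the axis.
def check_and_update_sparse_slice_attrs(in_dims, axes, starts, ends):
    rank = len(in_dims)
    axes[:] = [max(0, a + rank) if a < 0 else a for a in axes]
    assert len(axes) == len(starts) == len(ends)
    for i, axis in enumerate(axes):
        dim = in_dims[axis]
        start = starts[i] + dim if starts[i] < 0 else starts[i]
        end = ends[i] + dim if ends[i] < 0 else ends[i]
        starts[i] = min(max(start, 0), dim)
        ends[i] = min(max(end, 0), dim)
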
inline void ConstructNewSliceAttrs(const phi::DDim& x_dims,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
std::vector<int64_t>* new_axes,
std::vector<int64_t>* new_starts,
std::vector<int64_t>* new_ends) {
for (int64_t i = 0; i < x_dims.size(); ++i) {
int pos = -1;
for (int j = 0; j < static_cast<int>(axes.size()); ++j) {
if (axes[j] == i) {
pos = j;
break;
}
}
if (pos == -1) {
(*new_axes)[i] = i;
(*new_starts)[i] = 0;
(*new_ends)[i] = x_dims[i];
} else {
(*new_axes)[i] = axes[pos];
(*new_starts)[i] = starts[pos];
(*new_ends)[i] = ends[pos];
}
}
}
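
ConstructNewSliceAttrs expands the user-supplied attributes so that every dimension of x is covered; dimensions not mentioned in axes get the full range [0, dim). A Python sketch of the same expansion (hypothetical helper, for illustration only):

def construct_new_slice_attrs(x_dims, axes, starts, ends):
    new_axes, new_starts, new_ends = [], [], []
    for i, dim in enumerate(x_dims):
        if i in axes:
            j = axes.index(i)  # position of axis i in the user's list
            new_axes.append(axes[j])
            new_starts.append(starts[j])
            new_ends.append(ends[j])
        else:
            new_axes.append(i)   # untouched dim: keep the full range
            new_starts.append(0)
            new_ends.append(dim)
    return new_axes, new_starts, new_ends
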
} // namespace funcs
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SliceCooGradCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* x_grad) {
// set x_grad
const int64_t out_grad_nnz = out_grad.nnz();
auto sparse_dim = static_cast<int64_t>(out_grad.sparse_dim());
DenseTensor dx_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_indices_data = dx_indices.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
const auto* out_grad_indices_data = out_grad.indices().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
for (int64_t j = 0; j < out_grad_nnz; ++j) {
// set indices
for (int64_t i = 0; i < sparse_dim; ++i) {
dx_indices_data[i * out_grad_nnz + j] =
out_grad_indices_data[i * out_grad_nnz + j];
}
for (size_t ii = 0; ii < axes.size(); ++ii) {
int64_t i = axes[ii];
dx_indices_data[i * out_grad_nnz + j] += starts[ii];
}
// set value
dx_values_data[j] = out_grad_values_data[j];
}
x_grad->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
}
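
The COO slice gradient only has to undo the index rebasing done by the forward pass: every nonzero of out_grad maps back to a nonzero of x by adding starts along each sliced axis, and every other position of x gets an implicit zero gradient. A NumPy sketch of the same computation (hypothetical helper, not part of this PR):

import numpy as np

def slice_coo_grad(out_grad_indices, out_grad_values, axes, starts):
    # out_grad_indices has shape (sparse_dim, nnz); shift the sliced
    # axes back into x's coordinate frame and copy the values unchanged.
    dx_indices = out_grad_indices.copy()
    for ii, axis in enumerate(axes):
        dx_indices[axis, :] += starts[ii]
    return dx_indices, out_grad_values.copy()
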
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCooGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
template <typename T>
void GetCsrInputGradCrows(const int64_t* out_grad_crows_data,
const int64_t out_grad_n_rows,
const int64_t x_n_rows,
const int64_t rows_start,
int64_t* dx_crows_data,
const int64_t out_grad_crows_offset = 0,
const int64_t dx_crows_offset = 0) {
for (int64_t i = 0; i < x_n_rows + 1; ++i) {
int64_t idx = i + dx_crows_offset;
if (i < rows_start) {
dx_crows_data[idx] = 0;
} else if (i < rows_start + out_grad_n_rows + 1) {
int64_t out_grad_idx = out_grad_crows_offset + (i - rows_start);
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
} else {
int64_t out_grad_idx = out_grad_crows_offset + out_grad_n_rows;
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
}
}
}
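
GetCsrInputGradCrows rebuilds the row-pointer array of the input gradient from that of out_grad: rows before the sliced block hold 0, rows inside it mirror out_grad's prefix sums, and rows after it saturate at out_grad's total nnz. A Python sketch of the same reconstruction (hypothetical helper, for illustration):

def input_grad_crows(out_grad_crows, out_grad_n_rows, x_n_rows, rows_start):
    dx_crows = [0] * (x_n_rows + 1)
    for i in range(x_n_rows + 1):
        if i < rows_start:
            dx_crows[i] = 0                                  # before the slice
        elif i < rows_start + out_grad_n_rows + 1:
            dx_crows[i] = out_grad_crows[i - rows_start]     # inside the slice
        else:
            dx_crows[i] = out_grad_crows[out_grad_n_rows]    # after the slice
    return dx_crows
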
template <typename T, typename Context>
void SliceCsrGrad2D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t out_grad_nnz = out_grad.nnz();
const int64_t n_rows = x.dims()[0];
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {n_rows + 1});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
// set cols
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_cols_data[i] = out_grad_cols_data[i] + starts[1];
}
// set values
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_values_data[i] = out_grad_values_data[i];
}
// set crows
const int64_t out_grad_n_rows = out_grad.dims()[0];
GetCsrInputGradCrows<T>(out_grad_crows_data,
out_grad_n_rows,
n_rows,
starts[0],
dx_crows_data,
0,
0);
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
}
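
SliceCsrGrad2D combines the pieces above: the gradient's cols are out_grad's cols shifted back by the column start, values are copied, and crows is rebuilt row by row. A sketch reusing the hypothetical input_grad_crows helper from the previous sketch:

def slice_csr_grad_2d(out_grad_crows, out_grad_cols, out_grad_values,
                      x_n_rows, starts):
    dx_cols = [c + starts[1] for c in out_grad_cols]  # shift columns back
    dx_values = list(out_grad_values)
    out_grad_n_rows = len(out_grad_crows) - 1
    dx_crows = input_grad_crows(out_grad_crows, out_grad_n_rows,
                                x_n_rows, starts[0])
    return dx_crows, dx_cols, dx_values
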
template <typename T, typename Context>
void SliceCsrGrad3D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t dim0 = x.dims()[0], n_rows = x.dims()[1];
const int64_t out_grad_nnz = out_grad.nnz();
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {dim0 * (n_rows + 1)});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
// set cols
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_cols_data[i] = out_grad_cols_data[i] + starts[2];
}
// set values
for (int64_t i = 0; i < out_grad_nnz; ++i) {
dx_values_data[i] = out_grad_values_data[i];
}
// set crows
int64_t out_grad_n_rows = out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
if (i < starts[0] || i >= ends[0]) {
for (int64_t j = 0; j < n_rows + 1; ++j) {
dx_crows_data[i * (n_rows + 1) + j] = 0;
}
} else {
int64_t out_grad_crows_offset = (i - starts[0]) * (out_grad_n_rows + 1);
int64_t dx_crows_offset = i * (n_rows + 1);
GetCsrInputGradCrows<T>(out_grad_crows_data,
out_grad_n_rows,
n_rows,
starts[1],
dx_crows_data,
out_grad_crows_offset,
dx_crows_offset);
}
}
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
}
template <typename T, typename Context>
void SliceCsrGradCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
// Construct new axes, starts, and ends
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);
const int64_t sparse_dim = x_dims.size();
if (sparse_dim == 2) {
SliceCsrGrad2D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else if (sparse_dim == 3) {
SliceCsrGrad3D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Slice grad for Sparse CSR Tensor only supports 2-D or 3-D, but got "
        "%d-D.",
        x_dims.size()));
}
}
template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(slice_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SliceCooCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* out) {
const phi::DDim& x_dims = x.dims();
// Step1: Infer output dims
auto out_dims = funcs::GetSliceDims<int64_t>(
x_dims, axes, starts, ends, nullptr, nullptr);
// Step2: Get out_nnz (the number of non-zero elements in output)
const int64_t x_nnz = x.nnz();
int64_t out_nnz = 0;
const auto* x_indices_data = x.indices().data<int64_t>();
for (int64_t j = 0; j < x_nnz; ++j) {
bool hit = true;
for (size_t ii = 0; ii < axes.size(); ++ii) {
auto item = x_indices_data[axes[ii] * x_nnz + j];
if (!(starts[ii] <= item && item < ends[ii])) {
hit = false;
break;
}
}
if (!hit) continue;
out_nnz++;
}
// Step3: Get the values and indices of output
auto sparse_dim = static_cast<int64_t>(x.sparse_dim());
DenseTensor out_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_nnz});
DenseTensor out_values = phi::Empty<T, Context>(dev_ctx, {out_nnz});
auto* out_indices_data = out_indices.data<int64_t>();
auto* out_values_data = out_values.data<T>();
const auto* x_values_data = x.values().data<T>();
int64_t index = 0;
for (int64_t j = 0; j < x_nnz && index < out_nnz; ++j) {
bool hit = true;
for (size_t ii = 0; ii < axes.size(); ++ii) {
auto item = x_indices_data[axes[ii] * x_nnz + j];
if (!(starts[ii] <= item && item < ends[ii])) {
hit = false;
break;
}
}
if (!hit) continue;
// set value
out_values_data[index] = x_values_data[j];
// set coordinate
for (int64_t i = 0; i < sparse_dim; ++i) {
out_indices_data[i * out_nnz + index] = x_indices_data[i * x_nnz + j];
}
for (size_t ii = 0; ii < axes.size(); ++ii) {
auto i = axes[ii];
out_indices_data[i * out_nnz + index] -= starts[ii];
}
index++;
}
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}
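
SliceCooCompute scans the nonzeros twice: the first pass counts how many fall inside the slice window on every sliced axis, the second writes the survivors with their indices rebased by starts. A vectorized NumPy sketch of the same selection (the kernel itself uses two scalar passes; the helper name is hypothetical):

import numpy as np

def slice_coo(indices, values, axes, starts, ends):
    nnz = indices.shape[1]
    keep = np.ones(nnz, dtype=bool)
    for ii, axis in enumerate(axes):
        keep &= (indices[axis] >= starts[ii]) & (indices[axis] < ends[ii])
    out_indices = indices[:, keep]          # surviving coordinates (a copy)
    for ii, axis in enumerate(axes):
        out_indices[axis] -= starts[ii]     # rebase to the output frame
    return out_indices, values[keep]
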
template <typename T, typename Context>
void SliceCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* out) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCooCompute<T, Context>(dev_ctx, x, axes_vec, starts_vec, ends_vec, out);
}
int64_t GetCsrNonZeroNumber(const SparseCsrTensor& x,
const int64_t x_crows_start,
const int64_t x_crows_end,
const int64_t min_col,
const int64_t max_col,
const int64_t x_cols_offset = 0) {
const auto* x_crows_data = x.crows().data<int64_t>();
const auto* x_cols_data = x.cols().data<int64_t>();
int64_t out_nnz = 0;
for (int64_t i = x_crows_start; i < x_crows_end; ++i) {
int64_t st = x_crows_data[i] + x_cols_offset;
int64_t ed = x_crows_data[i + 1] + x_cols_offset;
for (int64_t jj = st; jj < ed; ++jj) {
if (x_cols_data[jj] >= min_col && x_cols_data[jj] < max_col) {
out_nnz++;
}
}
}
return out_nnz;
}
template <typename T>
void GetCsrSubMatrix(const SparseCsrTensor& x,
const int64_t x_crows_start,
const int64_t x_crows_end,
const int64_t min_col,
const int64_t max_col,
DenseTensor* out_crows,
DenseTensor* out_cols,
DenseTensor* out_values,
const int64_t x_cols_offset = 0,
const int64_t out_crows_offset = 0,
const int64_t out_cols_offset = 0) {
const auto* x_crows_data = x.crows().data<int64_t>();
const auto* x_cols_data = x.cols().data<int64_t>();
const auto* x_values_data = x.values().data<T>();
auto* out_crows_data = out_crows->data<int64_t>();
auto* out_cols_data = out_cols->data<int64_t>();
auto* out_values_data = out_values->data<T>();
out_crows_data[out_crows_offset] = 0;
int64_t index = 0, out_n_rows = x_crows_end - x_crows_start;
for (int i = 0; i < out_n_rows; ++i) {
int64_t st = x_crows_data[x_crows_start + i] + x_cols_offset;
int64_t ed = x_crows_data[x_crows_start + i + 1] + x_cols_offset;
for (int64_t jj = st; jj < ed; ++jj) {
if (x_cols_data[jj] >= min_col && x_cols_data[jj] < max_col) {
out_cols_data[out_cols_offset + index] = x_cols_data[jj] - min_col;
out_values_data[out_cols_offset + index] = x_values_data[jj];
index++;
}
}
out_crows_data[out_crows_offset + i + 1] = index;
}
}
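
Together, GetCsrNonZeroNumber and GetCsrSubMatrix extract one rectangular block from a single 2-D CSR matrix: the former counts the nonzeros whose column lies in [min_col, max_col) over the selected rows, the latter copies them and rebuilds the block's prefix sums. A pure-Python sketch of the extraction (hypothetical helper):

def csr_submatrix(crows, cols, values, r0, r1, c0, c1):
    # Keep rows [r0, r1) and columns [c0, c1); rebase both to the block.
    out_crows, out_cols, out_values = [0], [], []
    for i in range(r0, r1):
        for jj in range(crows[i], crows[i + 1]):
            if c0 <= cols[jj] < c1:
                out_cols.append(cols[jj] - c0)
                out_values.append(values[jj])
        out_crows.append(len(out_cols))
    return out_crows, out_cols, out_values
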
template <typename T, typename Context>
void SliceCsrTensor2D(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const phi::DDim& out_dims,
SparseCsrTensor* out) {
// Step1: Get nnz of out
int64_t out_nnz =
GetCsrNonZeroNumber(x, starts[0], ends[0], starts[1], ends[1], 0);
// Step2: Set out
int64_t out_n_rows = ends[0] - starts[0];
DenseTensor out_crows =
phi::Empty<int64_t, Context>(dev_ctx, {out_n_rows + 1});
DenseTensor out_cols = phi::Empty<int64_t, Context>(dev_ctx, {out_nnz});
DenseTensor out_values = phi::Empty<T, Context>(dev_ctx, {out_nnz});
GetCsrSubMatrix<T>(x,
starts[0],
ends[0],
starts[1],
ends[1],
&out_crows,
&out_cols,
&out_values,
0,
0,
0);
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
template <typename T, typename Context>
void SliceCsrTensor3D(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const phi::DDim& out_dims,
SparseCsrTensor* out) {
const auto* x_crows_data = x.crows().data<int64_t>();
// Step1: Get nnz of out
const int64_t x_dim0 = x.dims()[0], x_n_rows = x.dims()[1];
int64_t x_cols_offset = 0, out_nnz = 0;
  // all_nnzs stores the nnz of each slice along dim 0; it is used below to set out.
std::vector<int64_t> all_nnzs(ends[0] - starts[0]);
for (int64_t i = 0; i < x_dim0; ++i) {
if (i >= starts[0] && i < ends[0]) { // slice dim 0
int64_t x_crows_st = i * (x_n_rows + 1) + starts[1];
int64_t x_crows_ed = i * (x_n_rows + 1) + ends[1];
int64_t nnz = GetCsrNonZeroNumber(
x, x_crows_st, x_crows_ed, starts[2], ends[2], x_cols_offset);
out_nnz += nnz;
all_nnzs[i - starts[0]] = nnz;
}
// get the start index in non_zero_cols_
x_cols_offset += x_crows_data[(i + 1) * (x_n_rows + 1) - 1];
}
// Step2: Set out
const int64_t out_dim0 = out_dims[0], out_n_rows = out_dims[1];
DenseTensor out_crows =
phi::Empty<int64_t, Context>(dev_ctx, {out_dim0 * (out_n_rows + 1)});
DenseTensor out_cols = phi::Empty<int64_t, Context>(dev_ctx, {out_nnz});
DenseTensor out_values = phi::Empty<T, Context>(dev_ctx, {out_nnz});
x_cols_offset = 0;
int64_t out_crows_offset = 0, out_cols_offset = 0;
for (int64_t i = 0; i < x_dim0; ++i) {
if (i >= starts[0] && i < ends[0]) { // slice dim 0
int64_t x_crows_start = i * (x_n_rows + 1) + starts[1];
int64_t x_crows_end = i * (x_n_rows + 1) + ends[1];
GetCsrSubMatrix<T>(x,
x_crows_start,
x_crows_end,
starts[2],
ends[2],
&out_crows,
&out_cols,
&out_values,
x_cols_offset,
out_crows_offset,
out_cols_offset);
out_crows_offset += (out_n_rows + 1);
out_cols_offset += all_nnzs[i - starts[0]];
}
x_cols_offset += x_crows_data[(i + 1) * (x_n_rows + 1) - 1];
}
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
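
The 3-D case relies on the batched CSR layout: crows stores dim0 row-pointer arrays of length n_rows + 1 back to back, while cols and values are concatenated across batches, so x_cols_offset must advance by the nnz of each batch, which is the last entry of that batch's crows segment. A sketch of that bookkeeping (hypothetical helper):

def batch_col_offsets(crows, dim0, n_rows):
    # Start of each batch's segment in the flat cols/values arrays.
    offsets, running = [], 0
    for i in range(dim0):
        offsets.append(running)
        running += crows[(i + 1) * (n_rows + 1) - 1]  # nnz of batch i
    return offsets
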
template <typename T, typename Context>
void SliceCsrCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* out) {
const phi::DDim& x_dims = x.dims();
// Step1: Infer output dims
auto out_dims = funcs::GetSliceDims<int64_t>(
x_dims, axes, starts, ends, nullptr, nullptr);
// Step2: Construct new axes, starts and ends.
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);
  // Step3: Slice the CSR tensor according to its rank
if (x_dims.size() == 2) {
SliceCsrTensor2D<T, Context>(
dev_ctx, x, new_axes, new_starts, new_ends, out_dims, out);
} else if (x_dims.size() == 3) {
SliceCsrTensor3D<T, Context>(
dev_ctx, x, new_axes, new_starts, new_ends, out_dims, out);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Slice for Sparse CSR Tensor only supports 2-D or 3-D, but got %d-D.",
        x_dims.size()));
}
}
template <typename T, typename Context>
void SliceCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* out) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update attr
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrCompute<T, Context>(dev_ctx, x, axes_vec, starts_vec, ends_vec, out);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(slice_csr,
CPU,
ALL_LAYOUT,
phi::sparse::SliceCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
namespace sparse {
template <typename T>
__global__ void GetCooInputGradCudaKernel(const int64_t* out_grad_indices_data,
const T* out_grad_values_data,
const int64_t* axes,
const int64_t* starts,
const int64_t axes_size,
const int64_t sparse_dim,
const int64_t out_grad_nnz,
int64_t* dx_indices_data,
T* dx_values_data) {
CUDA_KERNEL_LOOP_TYPE(j, out_grad_nnz, int64_t) {
// set indices
for (int64_t i = 0; i < sparse_dim; ++i) {
dx_indices_data[i * out_grad_nnz + j] =
out_grad_indices_data[i * out_grad_nnz + j];
}
for (size_t ii = 0; ii < axes_size; ++ii) {
int64_t i = axes[ii];
dx_indices_data[i * out_grad_nnz + j] += starts[ii];
}
// set value
dx_values_data[j] = out_grad_values_data[j];
}
}
template <typename T, typename Context>
void SliceCooGradCompute(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
// copy axes to device
auto d_axes_tensor = memory_utils::Alloc(
dev_ctx.GetPlace(),
sizeof(int64_t) * axes.size(),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int64_t* d_axes = reinterpret_cast<int64_t*>(d_axes_tensor->ptr());
memory_utils::Copy(dev_ctx.GetPlace(),
d_axes,
phi::CPUPlace(),
axes.data(),
sizeof(int64_t) * axes.size(),
dev_ctx.stream());
// copy starts to device
auto d_starts_tensor = memory_utils::Alloc(
dev_ctx.GetPlace(),
sizeof(int64_t) * starts.size(),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
memory_utils::Copy(dev_ctx.GetPlace(),
d_starts,
phi::CPUPlace(),
starts.data(),
sizeof(int64_t) * starts.size(),
dev_ctx.stream());
  // Set the indices and values of x_grad
const int64_t out_grad_nnz = out_grad.nnz();
auto sparse_dim = static_cast<int64_t>(out_grad.sparse_dim());
DenseTensor dx_indices =
phi::Empty<int64_t, Context>(dev_ctx, {sparse_dim, out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_indices_data = dx_indices.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
const auto* out_grad_indices_data = out_grad.indices().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
x_grad->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_grad_nnz + 1, 1);
GetCooInputGradCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_indices_data,
out_grad_values_data,
d_axes,
d_starts,
axes.size(),
sparse_dim,
out_grad_nnz,
dx_indices_data,
dx_values_data);
}
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// Check and update sparse slice attrs
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCooGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
template <typename T>
__global__ void GetCsrInputColsValuesCudaKernel(
const int64_t* out_grad_cols_data,
const T* out_grad_values_data,
const int64_t out_grad_nnz,
const int64_t cols_start,
int64_t* dx_cols_data,
T* dx_values_data) {
CUDA_KERNEL_LOOP_TYPE(i, out_grad_nnz, int64_t) {
dx_cols_data[i] = out_grad_cols_data[i] + cols_start;
dx_values_data[i] = out_grad_values_data[i];
}
}
__global__ void GetCsrInputCrowsCudaKernel(
const int64_t* out_grad_crows_data,
const int64_t out_grad_n_rows,
const int64_t out_grad_nnz,
const int64_t x_n_rows,
const int64_t rows_start,
const int64_t rows_end,
int64_t* dx_crows_data,
const int64_t dx_crows_offset = 0,
const int64_t out_grad_crows_offset = 0) {
CUDA_KERNEL_LOOP_TYPE(i, x_n_rows + 1, int64_t) {
int64_t idx = i + dx_crows_offset;
if (i < rows_start) {
dx_crows_data[idx] = 0;
} else if (i < rows_start + out_grad_n_rows + 1) {
int64_t out_grad_idx = out_grad_crows_offset + (i - rows_start);
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
} else {
int64_t out_grad_idx = out_grad_crows_offset + out_grad_n_rows;
dx_crows_data[idx] = out_grad_crows_data[out_grad_idx];
}
}
}
template <typename T, typename Context>
void SliceCsrGrad2D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t out_grad_nnz = out_grad.nnz();
const int64_t n_rows = x.dims()[0];
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {n_rows + 1});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
// set cols and values
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_grad_nnz + 1, 1);
GetCsrInputColsValuesCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_cols_data,
out_grad_values_data,
out_grad_nnz,
starts[1],
dx_cols_data,
dx_values_data);
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n_rows + 1, 1);
GetCsrInputCrowsCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_crows_data,
out_grad.dims()[0],
out_grad_nnz,
x.dims()[0],
starts[0],
ends[0],
dx_crows_data,
0,
0);
}
__global__ void GetCsrInputCrowsPart1CudaKernel(const int64_t n_rows,
                                                const int64_t dim0_idx,
                                                int64_t* dx_crows_data) {
CUDA_KERNEL_LOOP_TYPE(j, n_rows + 1, int64_t) {
dx_crows_data[dim0_idx * (n_rows + 1) + j] = 0;
}
}
template <typename T, typename Context>
void SliceCsrGrad3D(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const int64_t dim0 = x.dims()[0], n_rows = x.dims()[1];
const int64_t out_grad_nnz = out_grad.nnz();
const auto* out_grad_crows_data = out_grad.crows().data<int64_t>();
const auto* out_grad_cols_data = out_grad.cols().data<int64_t>();
const auto* out_grad_values_data = out_grad.values().data<T>();
DenseTensor dx_crows = phi::Empty<int64_t>(dev_ctx, {dim0 * (n_rows + 1)});
DenseTensor dx_cols = phi::Empty<int64_t>(dev_ctx, {out_grad_nnz});
DenseTensor dx_values = phi::Empty<T, Context>(dev_ctx, {out_grad_nnz});
auto* dx_crows_data = dx_crows.data<int64_t>();
auto* dx_cols_data = dx_cols.data<int64_t>();
auto* dx_values_data = dx_values.data<T>();
x_grad->SetMember(dx_crows, dx_cols, dx_values, x.dims());
// set cols and values
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_grad_nnz + 1, 1);
GetCsrInputColsValuesCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_cols_data,
out_grad_values_data,
out_grad_nnz,
starts[2],
dx_cols_data,
dx_values_data);
// set crows
int64_t out_grad_n_rows = out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
if (i < starts[0] || i >= ends[0]) {
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n_rows + 1, 1);
      GetCsrInputCrowsPart1CudaKernel<<<config.block_per_grid.x,
                                        config.thread_per_block.x,
                                        0,
                                        dev_ctx.stream()>>>(
          n_rows, i, dx_crows_data);
} else {
int64_t dx_crows_offset = i * (n_rows + 1);
int64_t out_grad_crows_offset = (i - starts[0]) * (out_grad_n_rows + 1);
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n_rows + 1, 1);
GetCsrInputCrowsCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_grad_crows_data,
out_grad_n_rows,
out_grad_nnz,
n_rows,
starts[1],
ends[1],
dx_crows_data,
dx_crows_offset,
out_grad_crows_offset);
}
}
}
template <typename T, typename Context>
void SliceCsrGradCompute(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
// construct new axes, starts, and ends
std::vector<int64_t> new_axes(3), new_starts(3), new_ends(3);
funcs::ConstructNewSliceAttrs(
x_dims, axes, starts, ends, &new_axes, &new_starts, &new_ends);
const int64_t sparse_dim = x_dims.size();
if (sparse_dim == 2) {
SliceCsrGrad2D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else if (sparse_dim == 3) {
SliceCsrGrad3D<T, Context>(
dev_ctx, x, out_grad, new_axes, new_starts, new_ends, x_grad);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Slice grad for Sparse CSR Tensor only supports 2-D or 3-D, but got "
        "%d-D.",
        x_dims.size()));
}
}
template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad) {
const phi::DDim& x_dims = x.dims();
std::vector<int64_t> axes_vec = axes.GetData();
std::vector<int64_t> starts_vec = starts.GetData();
std::vector<int64_t> ends_vec = ends.GetData();
// update starts and ends
funcs::CheckAndUpdateSparseSliceAttrs<int64_t>(
x_dims, &axes_vec, &starts_vec, &ends_vec);
SliceCsrGradCompute<T, Context>(
dev_ctx, x, out_grad, axes_vec, starts_vec, ends_vec, x_grad);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(slice_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SliceCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(slice_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SliceCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
(This diff has been collapsed.)
......@@ -121,5 +121,22 @@ void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
SparseCsrTensor* dx);
template <typename T, typename Context>
void SliceCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* x_grad);
template <typename T, typename Context>
void SliceCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& out_grad,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* x_grad);
} // namespace sparse
} // namespace phi
......@@ -225,5 +225,21 @@ SparseCsrTensor ReshapeCsr(const Context& dev_ctx,
return csr;
}
template <typename T, typename Context>
void SliceCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCooTensor* out);
template <typename T, typename Context>
void SliceCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& axes,
const phi::IntArray& starts,
const phi::IntArray& ends,
SparseCsrTensor* out);
} // namespace sparse
} // namespace phi
......@@ -38,6 +38,7 @@ from .unary import transpose
from .unary import sum
from .unary import reshape
from .unary import isnan
from .unary import slice
from .binary import mv
from .binary import matmul
......@@ -87,4 +88,5 @@ __all__ = [
'is_same_shape',
'reshape',
'isnan',
'slice',
]
......@@ -836,3 +836,87 @@ def isnan(x, name=None):
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs={}
)
return out
def slice(x, axes, starts, ends, name=None):
"""
    This operator produces a slice of ``x`` along multiple axes for sparse tensors.
    The ``axes``, ``starts`` and ``ends`` attributes specify, for each axis in the
    list of axes, the start and end positions used to slice the input sparse
    tensor ``x``. If a negative value such as :math:`-i` is passed to ``starts``
    or ``ends``, it is interpreted as :math:`n-i`, where :math:`n` is the number
    of elements in that dimension (0 is the initial position).
    If a value passed to ``starts`` or ``ends`` is greater than the number of elements
    in the dimension (n), it is treated as n.
    For slicing to the end of a dimension of unknown size, it is recommended to pass
    in INT_MAX. ``axes``, ``starts`` and ``ends`` must have the same length.
Args:
x (Tensor): The input Tensor (``SparseCooTensor`` or ``SparseCsrTensor``), it's data type should be ``float16``, ``float32``, ``float64``, ``int32``, ``int64``.
        axes (list|tuple|Tensor): The data type is ``int32``. If ``axes`` is a list or tuple, its elements
            should be integers or Tensors with shape [1]. If ``axes`` is a Tensor, it should be a 1-D Tensor.
            These are the axes that ``starts`` and ``ends`` apply to.
starts (list|tuple|Tensor): The data type is ``int32``. If ``starts`` is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If ``starts`` is a Tensor, it should be a 1-D Tensor.
It represents starting indices of corresponding axis in ``axes``.
ends (list|tuple|Tensor): The data type is ``int32``. If ``ends`` is a list or tuple, the elements of
it should be integers or Tensors with shape [1]. If ``ends`` is a Tensor, it should be a 1-D Tensor.
It represents ending indices of corresponding axis in ``axes``.
Returns:
A Sparse Tensor. The data type is same as ``x``.
Examples:
.. code-block:: python
import paddle
import numpy as np
format = 'coo'
np_x = np.asarray([[4, 0, 7, 0], [0, 0, 5, 0], [-4, 2, 0, 0]])
dense_x = paddle.to_tensor(np_x)
if format == 'coo':
sp_x = dense_x.to_sparse_coo(len(np_x.shape))
else:
sp_x = dense_x.to_sparse_csr()
axes = [0, 1]
starts = [1, 0]
ends = [3, -2]
sp_out = paddle.sparse.slice(sp_x, axes, starts, ends)
# sp_out is x[1:3, 0:-2]
print(sp_out)
# Tensor(shape=[2, 2], dtype=paddle.int64, place=Place(cpu), stop_gradient=True,
# indices=[[1, 1],
# [0, 1]],
# values=[-4, 2])
"""
if in_dynamic_mode():
return _C_ops.sparse_slice(x, axes, starts, ends)
else:
attrs = {'axes': axes, 'starts': starts, 'ends': ends}
check_variable_and_dtype(
x,
'x',
[
'bool',
'float32',
'float64',
'int16',
'int32',
'int64',
],
'sparse_slice',
)
check_type(axes, 'axes', (list, tuple), 'sparse_slice')
check_type(starts, 'starts', (list, tuple), 'sparse_slice')
check_type(ends, 'ends', (list, tuple), 'sparse_slice')
op_type = 'sparse_slice'
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs
)
return out
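
The docstring example uses the COO format; the same call works for CSR tensors. A brief illustrative variant of the calls (a sketch mirroring the COO example above):

import paddle

dense_x = paddle.to_tensor([[4, 0, 7, 0], [0, 0, 5, 0], [-4, 2, 0, 0]],
                           dtype='float32')
sp_x = dense_x.to_sparse_csr()
# Slice rows 1:3 and columns 0:-2, as in the COO example.
sp_out = paddle.sparse.slice(sp_x, [0, 1], [1, 0], [3, -2])
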
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
data_5d = [
[[2, 3, 4, 5, 6], [0, 1, 2, 4], [0, 1, 2, -4], [3, 3, 4, -2]],
]
data_4d = [
[[2, 3, 4, 5], [0, 1, 2, 3], [0, 1, 2, -4], [3, 3, 4, -2]],
]
data_3d = [
[[4, 4, 5], [-3, -2, -1], [1, -3, 2], [3, 3, 4]],
[[4, 4, 5], [0, 1, 2], [0, 1, 2], [3, 3, 4]],
[[4, 4, 5], [-1], [0], [2]],
[[4, 4, 5], [0], [1], [2]],
[[4, 4, 5], [1], [2], [3]],
[[4, 4, 5], [1, 2], [2, 2], [3, 4]],
[[4, 4, 5], [0, 2], [2, 2], [3, 4]],
]
data_2d = [
[[3, 4], [0], [0], [2]],
[[3, 4], [1], [-3], [2]],
[[3, 4], [-2, -1], [-3, 0], [2, -1]],
[[78, 78], [0, -1], [32, 58], [-2, -1]],
]
devices = ['cpu']
if paddle.device.get_device() != "cpu":
devices.append(paddle.device.get_device())
class TestSparseSlice(unittest.TestCase):
"""
Test the API paddle.sparse.slice on some sparse tensors.
x: sparse, out: sparse
"""
def _check_result(self, np_x, axes, starts, ends, format='coo'):
for device in devices:
paddle.device.set_device(device)
self._check_result_with_place(np_x, axes, starts, ends, format)
def _check_result_with_place(self, np_x, axes, starts, ends, format='coo'):
x_shape = np_x.shape
dense_x = paddle.to_tensor(np_x)
dense_x.stop_gradient = False
dense_out = paddle.slice(dense_x, axes, starts, ends)
if format == 'coo':
sp_x = paddle.to_tensor(np_x).to_sparse_coo(len(x_shape))
else:
sp_x = paddle.to_tensor(np_x).to_sparse_csr()
sp_x.stop_gradient = False
sp_out = paddle.sparse.slice(sp_x, axes, starts, ends)
np.testing.assert_allclose(
sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-5
)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(
sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy() * np_x.astype('bool').astype('int'),
rtol=1e-5,
)
def check_result_with_shape(
self, x_shape, axes, starts, ends, format='coo'
):
mask = np.random.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
self._check_result(np_x, axes, starts, ends, format)
def check_result_with_list(self, x, axes, starts, ends, format='coo'):
np_x = np.array(x)
self._check_result(np_x, axes, starts, ends, format)
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')
def test_coo_2d(self):
x = [[1, 2, 3, 4], [0, 1, 2, 0]]
self.check_result_with_list(x, [0, 1], [0, 1], [2, 3], format='coo')
for item in data_2d:
self.check_result_with_shape(*item, format='coo')
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
def test_csr_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='csr')
def test_csr_3d_zero(self):
x = [[[0, 0, 1, 2], [0, 0, 0, 2]]]
self.check_result_with_list(x, [1, 2], [0, 0], [2, 2], format='csr')
def test_csr_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='csr')
def test_csr_2d_zero(self):
x = [[0, 0, 1, 2], [0, 0, 0, 1]]
self.check_result_with_list(x, [0, 1], [0, 0], [2, 2], format='csr')
class TestSparseCooSliceStatic(unittest.TestCase):
def _check_result_coo(self, np_x, axes, starts, ends):
for device in devices:
paddle.device.set_device(device)
self._check_result_coo_with_place(np_x, axes, starts, ends)
def _check_result_coo_with_place(self, np_x, axes, starts, ends):
x_shape = np_x.shape
dense_x = paddle.to_tensor(np_x)
dense_x.stop_gradient = False
dense_out = paddle.slice(dense_x, axes, starts, ends)
sp_x = paddle.to_tensor(
np_x,
).to_sparse_coo(len(x_shape))
indices_data = sp_x.detach().indices()
values_data = sp_x.detach().values()
paddle.enable_static()
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
indices = paddle.static.data(
name='indices',
shape=indices_data.shape,
dtype=indices_data.dtype,
)
values = paddle.static.data(
name='values',
shape=values_data.shape,
dtype=values_data.dtype,
)
sp_x = paddle.sparse.sparse_coo_tensor(
indices,
values,
shape=dense_x.shape,
dtype=dense_x.dtype,
)
sp_out = paddle.sparse.slice(sp_x, axes, starts, ends)
sp_dense_out = sp_out.to_dense()
exe = paddle.static.Executor()
res = exe.run(
feed={
'indices': indices_data.numpy(),
'values': values_data.numpy(),
},
fetch_list=[sp_dense_out],
return_numpy=True,
)
np.testing.assert_allclose(
dense_out.numpy(),
res[0],
rtol=1e-5,
)
paddle.disable_static()
def check_result_with_shape(
self, x_shape, axes, starts, ends, format='coo'
):
mask = np.random.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)
def check_result_with_list(self, x, axes, starts, ends, format='coo'):
np_x = np.array(x)
if format == 'coo':
self._check_result_coo(np_x, axes, starts, ends)
def test_coo_5d(self):
for item in data_5d:
self.check_result_with_shape(*item, format='coo')
def test_coo_4d(self):
for item in data_4d:
self.check_result_with_shape(*item, format='coo')
def test_coo_3d(self):
for item in data_3d:
self.check_result_with_shape(*item, format='coo')
def test_coo_2d(self):
for item in data_2d:
self.check_result_with_shape(*item, format='coo')
def test_coo_1d(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [3], [5], format='coo')
def test_coo_1d_zero(self):
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
self.check_result_with_list(x, [0], [-3], [-1], format='coo')
if __name__ == "__main__":
unittest.main()