Unverified commit 14c642cb, authored by Zhan Rongrui and committed by GitHub

[Hackathon 4th No.30] Add the paddle.sparse.sum sparse API to Paddle (#51406)

Parent 0f1b077b
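For context, a minimal dygraph sketch of the new API follows. The tensor values come from the docstring example added in this diff, and the gradient call mirrors the unit test added below; the snippet is illustrative and not part of the patch itself.

import paddle

# Values taken from the docstring example in this PR.
dense_x = paddle.to_tensor([[-2.0, 0.0], [1.0, 2.0]])
sparse_x = dense_x.to_sparse_coo(1)
sparse_x.stop_gradient = False

out_all = paddle.sparse.sum(sparse_x)                         # [1.]
out_axis0 = paddle.sparse.sum(sparse_x, axis=0)               # [-1., 2.]
out_keep = paddle.sparse.sum(sparse_x, axis=1, keepdim=True)  # [[-2.], [3.]]

# The gradient comes back as a sparse tensor laid out like the input.
out_keep.backward()
print(sparse_x.grad.to_dense())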
......@@ -367,6 +367,17 @@
func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
- backward_op : sum_grad
forward : sum(Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray axis={}, bool keepdim=false)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sum_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
sum_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : sync_batch_norm_grad
forward : sync_batch_norm_(Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics)
......
......@@ -334,6 +334,17 @@
layout : x
backward : subtract_grad
- op : sum
args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : SumInferMeta
kernel :
func : sum_coo{sparse_coo -> sparse_coo},
sum_csr{sparse_csr -> sparse_csr}
data_type : x
backward : sum_grad
- op : sync_batch_norm_
args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
......
......@@ -51,6 +51,7 @@ PD_REGISTER_KERNEL(sum_grad,
float,
double,
phi::dtype::float16,
int16_t,
int,
int64_t,
phi::dtype::complex<float>,
......
......@@ -67,6 +67,7 @@ PD_REGISTER_KERNEL(sum_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int16_t,
int,
int64_t,
phi::dtype::complex<float>,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT, typename Context>
void SumCooGradCPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
unsigned int n_dim = axis.size();
const DenseTensor& x_indices = x.indices();
const DenseTensor& dout_indices = dout.indices();
const DenseTensor& dout_values = dout.values();
const auto* dout_indices_data = dout_indices.data<int64_t>();
const auto* dout_values_data = dout_values.data<T>();
DenseTensor* dx_indices = dx->mutable_indices();
DenseTensor* dx_values = dx->mutable_values();
*dx_indices = x_indices;
const auto* dx_indices_data = dx_indices->data<int64_t>();
auto* dx_values_data = dx_values->data<T>();
phi::funcs::SetConstant<Context, T> set_constant;
if (n_dim == 0) {
T value = dout_values.data<T>()[0];
set_constant(dev_ctx, dx_values, value);
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
return;
}
auto dim = axis[0] < 0 ? x.dims().size() + axis[0] : axis[0];
auto sparse_dim = x.sparse_dim();
if (dim >= sparse_dim) {
dim = dim - sparse_dim + 1;
phi::ReduceSumGradKernel<T, Context>(
dev_ctx, x.values(), dout.values(), {dim}, keep_dim, false, dx_values);
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
return;
}
// Ensure the sparse_dim is not less than 1.
if (sparse_dim == 1) {
keep_dim = true;
}
int64_t dense_dim = 1;
for (auto i = 1; i < x.values().dims().size(); ++i) {
dense_dim *= x.values().dims()[i];
}
std::map<std::vector<IntT>, int64_t> indices_map;
for (auto j = 0; j < dout_indices.dims()[1]; ++j) {
std::vector<IntT> pos;
for (int i = 0; i < dout_indices.dims()[0]; ++i) {
pos.push_back(dout_indices_data[j + i * dout_indices.dims()[1]]);
}
indices_map[pos] = j;
}
for (auto j = 0; j < dx_indices->dims()[1]; ++j) {
std::vector<IntT> pos;
for (int i = 0; i < dx_indices->dims()[0]; ++i) {
if (i != dim) {
pos.push_back(dx_indices_data[j + i * dx_indices->dims()[1]]);
} else if (keep_dim) {
pos.push_back(0);
}
}
for (int i = 0; i < dense_dim; ++i) {
dx_values_data[i + j * dense_dim] =
dout_values_data[i + indices_map[pos] * dense_dim];
}
}
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
}
template <typename T, typename Context>
void SumCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
unsigned int n_dim = axis.size();
const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
const DenseTensor& dout_values = dout.values();
const auto* x_crows_data = x_crows.data<int64_t>();
DenseTensor* dx_crows = dx->mutable_crows();
DenseTensor* dx_cols = dx->mutable_cols();
DenseTensor* dx_values = dx->mutable_values();
*dx_crows = x_crows;
*dx_cols = x_cols;
phi::funcs::SetConstant<Context, T> set_constant;
if (n_dim == 0) {
T value = dout_values.data<T>()[0];
set_constant(dev_ctx, dx_values, value);
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
return;
}
PADDLE_ENFORCE_EQ(axis[0],
-1,
phi::errors::Unimplemented(
"`axis` of SumCsrKernel only support None or -1 now."
"More number will be supported in the future."));
if (x.dims().size() == 2) {
int value_index = 0;
for (int k = 0; k < x.dims()[0]; ++k) {
if (x_crows_data[k] == x_crows_data[k + 1]) {
continue;
}
T value = dout_values.data<T>()[value_index];
set_constant(dev_ctx, dx_values, value);
value_index += 1;
}
} else {
int dout_value_index = 0;
int dx_value_index = 0;
for (auto batch = 0; batch < x.dims()[0]; ++batch) {
for (auto k = batch * (x.dims()[1] + 1);
k < batch * (x.dims()[1] + 1) + x.dims()[1];
++k) {
if (x_crows_data[k] == x_crows_data[k + 1]) {
continue;
}
T value = dout_values.data<T>()[dout_value_index];
for (auto i = x_crows_data[k]; i < x_crows_data[k + 1]; ++i) {
dx_values->data<T>()[dx_value_index] = value;
dx_value_index++;
}
dout_value_index++;
}
}
}
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
}
template <typename T, typename Context>
void SumCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCooTensor* dx) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SumCooGradCPUKernel", ([&] {
SumCooGradCPUKernel<T, data_t, Context>(
dev_ctx, x, dout, axis, keep_dim, dx);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sum_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SumCooGradKernel,
float,
double,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(sum_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SumCsrGradKernel,
float,
double,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT, typename Context>
void SumCooCPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCooTensor* out) {
size_t n_dim = axis.size();
auto sparse_dim = x.sparse_dim();
// create out sparse tensor
const auto& x_dims = x.dims();
const auto& x_indices = x.indices();
const auto& x_values = x.values();
DDim out_dims;
DenseTensor out_indices;
DenseTensor out_values;
if (n_dim == 0) {
std::vector<int64_t> out_indices_shape;
if (keep_dim) {
out_dims = make_ddim(std::vector<int64_t>(x_dims.size(), 1));
out_indices_shape = {sparse_dim, 1};
} else {
out_dims = make_ddim({1});
out_indices_shape = {1};
}
out_indices = Empty<IntT, Context>(dev_ctx, out_indices_shape);
auto* out_indices_data = out_indices.data<IntT>();
std::fill(out_indices_data, out_indices_data + out_indices.numel(), 0);
out_values = phi::Sum<T>(dev_ctx, x.values(), {}, dtype, keep_dim);
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
return;
}
auto dim = axis[0] < 0 ? x_dims.size() + axis[0] : axis[0];
const auto* x_indices_data = x_indices.data<IntT>();
const auto* x_values_data = x_values.data<T>();
std::vector<int64_t> dims;
for (int i = 0; i < x.dims().size(); ++i) {
if (i != dim) {
dims.emplace_back(x.dims()[i]);
} else if (keep_dim || (dim < sparse_dim && sparse_dim == 1)) {
dims.emplace_back(1);
}
}
out_dims = make_ddim(dims);
if (dim >= sparse_dim) {
out_indices = x_indices;
dim = dim - sparse_dim + 1;
out_values = phi::Sum<T>(dev_ctx, x.values(), {dim}, dtype, keep_dim);
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
return;
}
// Ensure the sparse_dim is not less than 1.
if (sparse_dim == 1) {
keep_dim = true;
}
// If the axis falls in the sparse dims and keep_dim is false, sparse_dim is reduced by one.
if (!keep_dim) {
sparse_dim -= 1;
}
// indices_map is a mapping from output's position to values to be summed.
std::map<std::vector<IntT>, std::vector<int64_t>> indices_map;
for (int64_t j = 0; j < x_indices.dims()[1]; ++j) {
std::vector<IntT> pos;
for (int64_t i = 0; i < x_indices.dims()[0]; ++i) {
if (dim != i) {
pos.emplace_back(x_indices_data[j + i * x_indices.dims()[1]]);
} else if (keep_dim) {
pos.emplace_back(0);
}
}
indices_map[pos].emplace_back(j);
}
std::vector<int> out_values_dims;
out_values_dims.push_back(static_cast<int>(indices_map.size()));
for (auto i = 1; i < x.values().dims().size(); ++i) {
out_values_dims.push_back(static_cast<int>(x.values().dims()[i]));
}
int64_t dense_dim = std::accumulate(out_values_dims.begin() + 1,
out_values_dims.end(),
1,
std::multiplies<int64_t>());
out_indices = Empty<IntT, Context>(
dev_ctx, {sparse_dim, static_cast<int>(indices_map.size())});
out_values = Empty<T, Context>(dev_ctx, out_values_dims);
auto* out_indices_data = out_indices.data<IntT>();
auto* out_values_data = out_values.data<T>();
auto iter_indices_map = indices_map.begin();
for (size_t j = 0; j < indices_map.size(); ++j) {
std::vector<IntT> pos = iter_indices_map->first;
std::vector<int64_t> values_index = iter_indices_map->second;
iter_indices_map++;
for (auto i = 0; i < sparse_dim; ++i) {
out_indices_data[j + i * indices_map.size()] = pos[i];
}
for (auto i = 0; i < dense_dim; ++i) {
T out_value = 0;
for (auto index : values_index) {
out_value += x_values_data[i + index * dense_dim];
}
out_values_data[i + j * dense_dim] = out_value;
}
}
if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) {
out_values = phi::Cast<T, Context>(dev_ctx, out_values, dtype);
}
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}
template <typename T, typename Context>
void SumCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCsrTensor* out) {
size_t n_dim = axis.size();
const auto& x_crows = x.crows();
const auto& x_values = x.values();
const auto* x_crows_data = x_crows.data<int64_t>();
const auto* x_values_data = x_values.data<T>();
DenseTensor out_crows, out_cols, out_values;
DDim out_dims;
if (n_dim == 0) {
if (keep_dim && x.dims().size() == 3) {
out_dims = make_ddim({1, 1, 1});
} else {
out_dims = make_ddim({1, 1});
}
out_crows = Empty<int64_t, Context>(dev_ctx, {2}); // crows = [0, 1]
auto* out_crows_data = out_crows.data<int64_t>();
out_crows_data[0] = 0;
out_crows_data[1] = 1;
out_cols = Empty<int64_t, Context>(dev_ctx, {1});  // cols = [0]
auto* out_cols_data = out_cols.data<int64_t>();
out_cols_data[0] = 0;
out_values = phi::Sum<T>(dev_ctx, x.values(), {}, dtype, true);
} else {
PADDLE_ENFORCE_EQ(axis[0],
-1,
phi::errors::Unimplemented(
"`axis` of SumCsrKernel only support None or -1 now."
"More number will be supported in the future."));
out_crows = EmptyLike<int64_t, Context>(dev_ctx, x.crows());
auto* out_crows_data = out_crows.data<int64_t>();
std::vector<T> out_data;
if (x.dims().size() == 2) {
out_crows_data[0] = 0;
out_dims = make_ddim({x.dims()[0], 1});
for (int i = 0; i < x.dims()[0]; ++i) {
if (x_crows_data[i] != x_crows_data[i + 1]) {
T sum_value = 0;
for (auto j = x_crows_data[i]; j < x_crows_data[i + 1]; ++j) {
sum_value += x_values_data[j];
}
out_crows_data[i + 1] = out_crows_data[i] + 1;
out_data.emplace_back(sum_value);
} else {
out_crows_data[i + 1] = out_crows_data[i];
}
}
} else {
if (keep_dim) {
out_dims = make_ddim({x.dims()[0], x.dims()[1], 1});
} else {
out_dims = make_ddim({x.dims()[0], x.dims()[1]});
}
int j = 0;
for (int batch = 0; batch < x.dims()[0]; ++batch) {
auto* cur_x_crows_data = x_crows_data + batch * x.dims()[2];
auto* cur_out_crows_data = out_crows_data + batch * x.dims()[2];
for (int i = 0; i < x.dims()[1]; ++i) {
cur_out_crows_data[0] = 0;
if (cur_x_crows_data[i] != cur_x_crows_data[i + 1]) {
T sum_value = 0;
for (auto k = cur_x_crows_data[i]; k < cur_x_crows_data[i + 1];
++k) {
sum_value += x_values_data[j++];
}
out_data.emplace_back(sum_value);
cur_out_crows_data[i + 1] = cur_out_crows_data[i] + 1;
} else {
cur_out_crows_data[i + 1] = cur_out_crows_data[i];
}
}
}
}
out_cols =
Empty<int64_t, Context>(dev_ctx, {static_cast<int>(out_data.size())});
out_values =
Empty<T, Context>(dev_ctx, {static_cast<int>(out_data.size())});
auto* out_cols_data = out_cols.data<int64_t>();
T* out_values_data = out_values.data<T>();
for (size_t i = 0; i < out_data.size(); ++i) {
out_cols_data[i] = 0;
out_values_data[i] = out_data[i];
}
if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) {
out_values = phi::Cast<T, Context>(dev_ctx, out_values, dtype);
}
}
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
template <typename T, typename Context>
void SumCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "SumCooCPUKernel", ([&] {
SumCooCPUKernel<T, data_t, Context>(
dev_ctx, x, axis, dtype, keep_dim, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sum_coo,
CPU,
ALL_LAYOUT,
phi::sparse::SumCooKernel,
float,
double,
int16_t,
int,
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(sum_csr,
CPU,
ALL_LAYOUT,
phi::sparse::SumCsrKernel,
float,
double,
int16_t,
int,
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
namespace phi {
namespace sparse {
template <typename T>
__global__ void SetValueCudaKernel(const T* value,
const int64_t length,
T* data) {
CUDA_KERNEL_LOOP_TYPE(index, length, int64_t) { data[index] = value[0]; }
}
template <typename T>
__global__ void SumCsr2DGradCudaKernel(const int64_t* x_crows_data,
const T* dout_values_data,
const int64_t x_dim0,
T* dx_values_data) {
// dout_crows_data[index] should be equal to index;
CUDA_KERNEL_LOOP_TYPE(index, x_dim0, int64_t) {
T value = dout_values_data[index];
for (auto i = x_crows_data[index]; i < x_crows_data[index + 1]; ++i) {
dx_values_data[i] = value;
}
}
}
template <typename T>
__global__ void SumCsr3DGradCudaKernel(const int64_t* x_crows_data,
const T* dout_values_data,
const int64_t x_dim0,
const int64_t x_dim1,
T* dx_values_data) {
// dout_crows_data[index] should be equal to number;
CUDA_KERNEL_LOOP_TYPE(index, x_dim0 * (x_dim1 + 1), int64_t) {
int64_t batch = index / (x_dim1 + 1);
int64_t number = index % (x_dim1 + 1);
// compute offset of dx_values_data in every batch
int64_t batch_offset = 0;
for (int64_t b = 1; b <= batch; ++b) {
batch_offset += x_crows_data[b * (x_dim1 + 1) - 1];
}
T value = dout_values_data[index - batch];
for (auto i = x_crows_data[index]; i < x_crows_data[index + 1]; ++i) {
dx_values_data[i + batch_offset] = value;
}
}
}
template <typename T, typename IntT, typename Context>
void SumCooGradGPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
unsigned int n_dim = axis.size();
const DenseTensor& x_indices = x.indices();
const DenseTensor& dout_indices = dout.indices();
const DenseTensor& dout_values = dout.values();
const auto* dout_indices_data = dout_indices.data<IntT>();
const auto* dout_values_data = dout_values.data<T>();
DenseTensor* dx_indices = dx->mutable_indices();
DenseTensor* dx_values = dx->mutable_values();
*dx_indices = x_indices;
const auto* dx_indices_data = dx_indices->data<IntT>();
auto* dx_values_data = dx_values->data<T>();
if (n_dim == 0) {
auto length = dx->nnz();
for (auto i = 1; i < x.values().dims().size(); ++i) {
length *= x.values().dims()[i];
}
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, length, 1);
SetValueCudaKernel<T>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(dout_values_data, length, dx_values_data);
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
return;
}
auto dim = axis[0] < 0 ? x.dims().size() + axis[0] : axis[0];
auto sparse_dim = x.sparse_dim();
if (dim >= sparse_dim) {
dim = dim - sparse_dim + 1;
phi::ReduceSumGradKernel<T, Context>(
dev_ctx, x.values(), dout.values(), {dim}, keep_dim, false, dx_values);
} else {
*dx_values = dout_values;
}
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
}
template <typename T, typename Context>
void SumCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
size_t n_dim = axis.size();
const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
const DenseTensor& dout_values = dout.values();
DenseTensor* dx_crows = dx->mutable_crows();
DenseTensor* dx_cols = dx->mutable_cols();
DenseTensor* dx_values = dx->mutable_values();
const auto* x_crows_data = x_crows.data<int64_t>();
const auto* dout_values_data = dout_values.data<T>();
auto* dx_values_data = dx_values->data<T>();
*dx_crows = x_crows;
*dx_cols = x_cols;
if (n_dim == 0) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dx->nnz(), 1);
SetValueCudaKernel<T>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(dout_values_data, dx->nnz(), dx_values_data);
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
return;
}
PADDLE_ENFORCE_EQ(axis[0],
-1,
phi::errors::Unimplemented(
"`axis` of SumCsrKernel only support None or -1 now."
"More number will be supported in the future."));
if (x.dims().size() == 2) {
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x.dims()[0], 1);
SumCsr2DGradCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
x_crows_data, dout_values_data, x.dims()[0], dx_values_data);
} else {
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, x.dims()[0] * (x.dims()[1] + 1), 1);
SumCsr3DGradCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x_crows_data,
dout_values_data,
x.dims()[0],
x.dims()[1],
dx_values_data);
}
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
}
template <typename T, typename Context>
void SumCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCooTensor* dx) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "SumCooGradGPUKernel", ([&] {
SumCooGradGPUKernel<T, data_t, Context>(
dev_ctx, x, dout, axis, keep_dim, dx);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sum_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SumCooGradKernel,
float,
double,
int16_t,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(sum_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SumCsrGradKernel,
float,
double,
int16_t,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT>
__global__ void SumCooCudaKernel(const IntT* x_indices_data,
const T* x_values_data,
const int64_t x_nnz,
const int64_t dense_dim,
const int64_t sparse_dim,
const int64_t axis,
const bool keep_dim,
IntT* out_indices_data,
T* out_values_data) {
CUDA_KERNEL_LOOP_TYPE(index_i, x_nnz, int64_t) {
int64_t i = 0;
for (int j = 0; j < dense_dim; ++j) {
out_values_data[j + index_i * dense_dim] = 0;
}
int64_t _index_j_ =
static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
for (auto index_j = _index_j_; index_j < x_nnz;
index_j += static_cast<int64_t>(blockDim.y) * gridDim.y) {
// Determine whether the index_i and index_j elements have the same
// indices in all dimensions except for the specified axis dimension.
bool same = true;
for (int j = 0; j < sparse_dim + !keep_dim; ++j) {
if (j != axis && x_indices_data[index_i + j * x_nnz] !=
x_indices_data[index_j + j * x_nnz]) {
same = false;
break;
}
}
if (same) {
for (int j = 0; j < dense_dim; ++j) {
phi::CudaAtomicAdd(&out_values_data[j + index_i * dense_dim],
x_values_data[j + index_j * dense_dim]);
}
}
}
if (_index_j_ != 0) {
return;
}
if (keep_dim) {
for (int j = 0; j < sparse_dim; ++j) {
if (j == axis) {
out_indices_data[index_i + j * x_nnz] = 0;
} else {
out_indices_data[index_i + j * x_nnz] =
x_indices_data[index_i + j * x_nnz];
}
}
return;
}
for (int j = 0; j < sparse_dim; ++j) {
// out_indices_data [sparse_dim, x.nnz()]
int64_t x_indices_data_offset;
if (j < axis) {
x_indices_data_offset = index_i + j * x_nnz;
} else {
x_indices_data_offset = index_i + (j + 1) * x_nnz;
}
out_indices_data[index_i + j * x_nnz] =
x_indices_data[x_indices_data_offset];
}
}
}
__global__ void SumAllCsrCudaKernel(int64_t* out_crows_data,
int64_t* out_cols_data) {
CUDA_KERNEL_LOOP_TYPE(index, 2, int64_t) {
out_crows_data[index] = index;
if (index == 0) {
out_cols_data[0] = 0;
}
}
}
template <typename T>
__global__ void SumCsr2DCudaKernel(const int64_t* x_crows_data,
const T* x_values_data,
const int64_t x_dim0,
int64_t* out_crows_data,
int64_t* out_cols_data,
T* out_values_data) {
CUDA_KERNEL_LOOP_TYPE(index, x_dim0 + 1, int64_t) {
out_crows_data[index] = index;
if (index != x_dim0) {
out_cols_data[index] = 0;
T sum_value = 0;
for (auto j = x_crows_data[index]; j < x_crows_data[index + 1]; ++j) {
sum_value += x_values_data[j];
}
out_values_data[index] = sum_value;
}
}
}
template <typename T>
__global__ void SumCsr3DCudaKernel(const int64_t* x_crows_data,
const T* x_values_data,
const int64_t x_dim0,
const int64_t x_dim1,
const int64_t* batch_nnz_data,
int64_t* out_crows_data,
int64_t* out_cols_data,
T* out_values_data) {
CUDA_KERNEL_LOOP_TYPE(index, x_dim0 * (x_dim1 + 1), int64_t) {
int64_t batch = index / (x_dim1 + 1);
int64_t number = index % (x_dim1 + 1);
out_crows_data[index] = number;
out_cols_data[index] = 0;
if (number != x_dim1) {
T sum_value = 0;
int64_t x_values_data_offset;
if (batch == 0) {
x_values_data_offset = 0;
} else {
x_values_data_offset = batch_nnz_data[batch - 1];
}
for (int64_t j = x_crows_data[index]; j < x_crows_data[index + 1]; ++j) {
sum_value += x_values_data[j + x_values_data_offset];
}
out_values_data[index - batch] = sum_value;
}
}
}
template <typename T, typename IntT, typename Context>
void SumCooGPU0Kernel(const Context& dev_ctx,
const SparseCooTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCooTensor* out) {
auto sparse_dim = x.sparse_dim();
// create out sparse tensor
const auto& x_dims = x.dims();
const auto& x_indices = x.indices();
const auto& x_values = x.values();
DDim out_dims;
DenseTensor out_indices;
DenseTensor out_values;
if (keep_dim) {
out_dims = make_ddim(std::vector<int64_t>(x_dims.size(), 1));
out_indices = Empty<IntT, Context>(dev_ctx, {sparse_dim, 1});
} else {
out_dims = make_ddim({1});
out_indices = Empty<IntT, Context>(dev_ctx, {1, 1});
}
phi::funcs::SetConstant<Context, IntT> set_out_indices;
set_out_indices(dev_ctx, &out_indices, static_cast<IntT>(0));
out_values = phi::Sum<T>(dev_ctx, x.values(), {}, dtype, keep_dim);
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}
template <typename T, typename IntT, typename Context>
void SumCooGPU1Kernel(const Context& dev_ctx,
const SparseCooTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCooTensor* out) {
auto sparse_dim = x.sparse_dim();
// create out sparse tensor
const auto& x_dims = x.dims();
const auto& x_indices = x.indices();
const auto& x_values = x.values();
DDim out_dims;
DenseTensor out_indices;
DenseTensor out_values;
auto n_dim = x.dims().size();
auto dim = axis[0] < 0 ? x_dims.size() + axis[0] : axis[0];
std::vector<int64_t> dims;
for (int i = 0; i < n_dim; ++i) {
if (i != dim) {
dims.emplace_back(x.dims()[i]);
} else if (keep_dim || (dim < sparse_dim && sparse_dim == 1)) {
dims.emplace_back(1);
}
}
out_dims = make_ddim(dims);
if (dim >= sparse_dim) {
out_indices = x_indices;
dim = dim - sparse_dim + 1;
out_values = phi::Sum<T>(dev_ctx, x.values(), {dim}, dtype, keep_dim);
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
return;
}
// Ensure the sparse_dim is not less than 1.
if (sparse_dim == 1) {
keep_dim = true;
}
// If the axis falls in the sparse dims and keep_dim is false, sparse_dim is reduced by one.
if (!keep_dim) {
sparse_dim -= 1;
}
std::vector<int> out_values_dims;
out_values_dims.push_back(x.nnz());
for (auto i = 1; i < x.values().dims().size(); ++i) {
out_values_dims.push_back(static_cast<int>(x.values().dims()[i]));
}
int64_t dense_dim = std::accumulate(out_values_dims.begin() + 1,
out_values_dims.end(),
1,
std::multiplies<int64_t>());
out_indices = Empty<IntT, Context>(dev_ctx, {sparse_dim, x.nnz()});
out_values = Empty<T, Context>(dev_ctx, out_values_dims);
const auto* x_indices_data = x_indices.data<IntT>();
const auto* x_values_data = x_values.data<T>();
auto* out_indices_data = out_indices.data<IntT>();
auto* out_values_data = out_values.data<T>();
auto config =
phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, x.nnz(), x.nnz());
SumCooCudaKernel<T, IntT><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x_indices_data,
x_values_data,
x.nnz(),
dense_dim,
sparse_dim,
dim,
keep_dim,
out_indices_data,
out_values_data);
if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) {
out_values = phi::Cast<T, Context>(dev_ctx, out_values, dtype);
}
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}
template <typename T, typename Context>
void SumCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCooTensor* out) {
const size_t n_dim = axis.size();
if (n_dim == 0) {
PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "SumCooGPUKernel", ([&] {
SumCooGPU0Kernel<T, data_t, Context>(
dev_ctx, x, axis, dtype, keep_dim, out);
}));
} else {
PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "SumCooGPUKernel", ([&] {
SumCooGPU1Kernel<T, data_t, Context>(
dev_ctx, x, axis, dtype, keep_dim, out);
}));
}
}
template <typename T, typename Context>
void SumCsr0Kernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCsrTensor* out) {
auto x_dim0 = x.dims()[0];
auto x_dim1 = x.dims()[1];
const auto& x_crows = x.crows();
const auto& x_values = x.values();
const auto* x_crows_data = x_crows.data<int64_t>();
const auto* x_values_data = x_values.data<T>();
DenseTensor out_crows, out_cols, out_values;
DDim out_dims;
if (keep_dim && x.dims().size() == 3) {
out_dims = make_ddim({1, 1, 1});
} else {
out_dims = make_ddim({1, 1});
}
out_crows = Empty<int64_t, Context>(dev_ctx, {2}); // crows = [0, 1]
out_cols = Empty<int64_t, Context>(dev_ctx, {1});  // cols = [0]
auto* out_crows_data = out_crows.data<int64_t>();
auto* out_cols_data = out_cols.data<int64_t>();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, 2, 1);
SumAllCsrCudaKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(out_crows_data, out_cols_data);
out_values = phi::Sum<T>(dev_ctx, x.values(), {}, dtype, true);
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
template <typename T, typename Context>
void SumCsr1Kernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCsrTensor* out) {
auto x_dim0 = x.dims()[0];
auto x_dim1 = x.dims()[1];
const auto& x_crows = x.crows();
const auto& x_values = x.values();
const auto* x_crows_data = x_crows.data<int64_t>();
const auto* x_values_data = x_values.data<T>();
DenseTensor out_crows, out_cols, out_values;
DDim out_dims;
out_crows = EmptyLike<int64_t, Context>(dev_ctx, x.crows());
auto* out_crows_data = out_crows.data<int64_t>();
if (x.dims().size() == 2) {
out_cols = Empty<int64_t, Context>(dev_ctx, {x_dim0});
out_values = Empty<T, Context>(dev_ctx, {x_dim0});
auto* out_cols_data = out_cols.data<int64_t>();
auto* out_values_data = out_values.data<T>();
out_dims = make_ddim({x_dim0, 1});
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_dim0 + 1, 1);
SumCsr2DCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x_crows_data,
x_values_data,
x_dim0,
out_crows_data,
out_cols_data,
out_values_data);
} else {
out_cols = Empty<int64_t, Context>(dev_ctx, {x_dim0 * x_dim1});
out_values = Empty<T, Context>(dev_ctx, {x_dim0 * x_dim1});
auto* out_cols_data = out_cols.data<int64_t>();
auto* out_values_data = out_values.data<T>();
if (keep_dim) {
out_dims = make_ddim({x_dim0, x_dim1, 1});
} else {
out_dims = make_ddim({x_dim0, x_dim1});
}
DenseTensor x_crows_reshape =
Reshape<int64_t, Context>(dev_ctx, x_crows, {x_dim0, x_dim1 + 1});
DenseTensor last_indices = Empty<int64_t, Context>(dev_ctx, {1});
phi::funcs::SetConstant<Context, int64_t> set_constant;
set_constant(dev_ctx, &last_indices, x_dim1);
DenseTensor x_crows_last = Empty<int64_t, Context>(dev_ctx, {x_dim0, 1});
IndexSelectKernel<int64_t, Context>(
dev_ctx, x_crows_reshape, last_indices, 1, &x_crows_last);
DenseTensor batch_nnz = Empty<int64_t, Context>(dev_ctx, {x_dim0, 1});
CumsumKernel<int64_t, Context>(
dev_ctx, x_crows_last, Scalar(0), false, false, false, &batch_nnz);
auto* batch_nnz_data = batch_nnz.data<int64_t>();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, x.dims()[0] * (x.dims()[1] + 1), 1);
SumCsr3DCudaKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(x_crows_data,
x_values_data,
x_dim0,
x_dim1,
batch_nnz_data,
out_crows_data,
out_cols_data,
out_values_data);
}
if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) {
out_values = phi::Cast<T, Context>(dev_ctx, out_values, dtype);
}
out->SetMember(out_crows, out_cols, out_values, out_dims);
}
template <typename T, typename Context>
void SumCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCsrTensor* out) {
size_t n_dim = axis.size();
if (n_dim == 0) {
SumCsr0Kernel<T, Context>(dev_ctx, x, axis, dtype, keep_dim, out);
} else {
PADDLE_ENFORCE_EQ(axis[0],
-1,
phi::errors::Unimplemented(
"`axis` of SumCsrKernel only support None or -1 now."
"More number will be supported in the future."));
SumCsr1Kernel<T, Context>(dev_ctx, x, axis, dtype, keep_dim, out);
}
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sum_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SumCooKernel,
float,
double,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(sum_csr,
GPU,
ALL_LAYOUT,
phi::sparse::SumCsrKernel,
float,
double,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
}
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
......@@ -92,6 +93,22 @@ void TransposeCsrGradKernel(const Context& dev_ctx,
const std::vector<int>& perm,
SparseCsrTensor* dx);
template <typename T, typename Context>
void SumCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCooTensor* dx);
template <typename T, typename Context>
void SumCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCsrTensor* dx);
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
......
......@@ -157,6 +157,22 @@ SparseCsrTensor TransposeCsr(const Context& dev_ctx,
return csr;
}
template <typename T, typename Context>
void SumCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCooTensor* out);
template <typename T, typename Context>
void SumCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const IntArray& axis,
DataType dtype,
bool keep_dim,
SparseCsrTensor* out);
template <typename T, typename Context>
SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) {
SparseCooTensor coo;
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
devices = ['cpu']
if paddle.device.get_device() != "cpu":
devices.append(paddle.device.get_device())
class TestSparseSum(unittest.TestCase):
"""
Test the API paddle.sparse.sum on some sparse tensors.
x: sparse tensor, out: sparse tensor
"""
def to_sparse(self, x, format, sparse_dim=None):
if format == 'coo':
if sparse_dim:
return x.detach().to_sparse_coo(sparse_dim=sparse_dim)
else:
return x.detach().to_sparse_coo(sparse_dim=x.ndim)
elif format == 'csr':
return x.detach().to_sparse_csr()
def check_result(
self, x_shape, dims, keepdim, format, sparse_dim=None, dtype=None
):
for device in devices:
paddle.device.set_device(device)
if sparse_dim:
mask_shape = [*x_shape[:sparse_dim]] + [1] * (
len(x_shape) - sparse_dim
)
mask = paddle.randint(0, 2, mask_shape)
else:
mask = paddle.randint(0, 2, x_shape)
while paddle.sum(mask) == 0:
if sparse_dim:
mask_shape = [*x_shape[:sparse_dim]] + [1] * (
len(x_shape) - sparse_dim
)
mask = paddle.randint(0, 2, mask_shape)
else:
mask = paddle.randint(0, 2, x_shape)
# "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask",
# or the backward checks may fail.
origin_x = (paddle.rand(x_shape, dtype='float64') + 1) * mask
dense_x = origin_x.detach()
dense_x.stop_gradient = False
dense_out = paddle.sum(dense_x, dims, keepdim=keepdim, dtype=dtype)
sp_x = self.to_sparse(origin_x, format, sparse_dim)
sp_x.stop_gradient = False
sp_out = paddle.sparse.sum(sp_x, dims, keepdim=keepdim, dtype=dtype)
np.testing.assert_allclose(
sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05
)
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(
sp_x.grad.to_dense().numpy(),
(dense_x.grad * mask).numpy(),
rtol=1e-05,
)
def test_sum_1d(self):
self.check_result([5], None, False, 'coo')
self.check_result([5], None, True, 'coo')
self.check_result([5], 0, False, 'coo')
self.check_result([5], 0, True, 'coo')
def test_sum_2d(self):
self.check_result([2, 5], None, False, 'coo', dtype="float32")
self.check_result([2, 5], None, True, 'coo')
self.check_result([2, 5], 0, True, 'coo', dtype="float32")
self.check_result([2, 5], 0, False, 'coo')
self.check_result([2, 5], 1, False, 'coo')
self.check_result([2, 5], None, True, 'csr', dtype="float32")
self.check_result([2, 5], -1, True, 'csr', dtype="float32")
self.check_result([2, 5], 0, False, 'coo')
self.check_result([2, 5], -1, True, 'csr')
def test_sum_3d(self):
self.check_result([6, 2, 3], -1, True, 'csr')
for i in [0, 1, -2, None]:
self.check_result([6, 2, 3], i, False, 'coo')
self.check_result([6, 2, 3], i, True, 'coo')
def test_sum_nd(self):
for i in range(6):
self.check_result([8, 3, 4, 4, 5, 3], i, False, 'coo')
self.check_result([8, 3, 4, 4, 5, 3], i, True, 'coo')
# Randint currently only supports dimensions 0 to 9.
self.check_result([2, 3, 4, 2, 3, 4, 2, 3, 4], i, False, 'coo')
def test_sum_sparse_dim(self):
for i in range(6):
self.check_result([8, 3, 4, 4, 5, 3], i, False, 'coo', sparse_dim=3)
self.check_result([8, 3, 4, 4, 5, 3], i, True, 'coo', sparse_dim=3)
class TestSparseSumStatic(unittest.TestCase):
def check_result_coo(self, x_shape, dims, keepdim, dtype=None):
for device in devices:
paddle.device.set_device(device)
mask = paddle.randint(0, 2, x_shape)
while paddle.sum(mask) == 0:
mask = paddle.randint(0, 2, x_shape)
origin_data = (paddle.rand(x_shape, dtype='float32') + 1) * mask
sparse_data = origin_data.detach().to_sparse_coo(
sparse_dim=len(x_shape)
)
indices_data = sparse_data.indices()
values_data = sparse_data.values()
dense_x = origin_data
dense_out = paddle.sum(dense_x, dims, keepdim=keepdim, dtype=dtype)
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
indices = paddle.static.data(
name='indices',
shape=indices_data.shape,
dtype=indices_data.dtype,
)
values = paddle.static.data(
name='values',
shape=values_data.shape,
dtype=values_data.dtype,
)
sp_x = paddle.sparse.sparse_coo_tensor(
indices,
values,
shape=origin_data.shape,
dtype=origin_data.dtype,
)
sp_out = paddle.sparse.sum(
sp_x, dims, keepdim=keepdim, dtype=dtype
)
sp_dense_out = sp_out.to_dense()
sparse_exe = paddle.static.Executor()
sparse_fetch = sparse_exe.run(
feed={
'indices': indices_data.numpy(),
"values": values_data.numpy(),
},
fetch_list=[sp_dense_out],
return_numpy=True,
)
np.testing.assert_allclose(
dense_out.numpy(), sparse_fetch[0], rtol=1e-5
)
paddle.disable_static()
def test_sum(self):
# 1d
self.check_result_coo([5], None, False)
self.check_result_coo([5], None, True)
self.check_result_coo([5], 0, True)
self.check_result_coo([5], 0, False)
# 2d
self.check_result_coo([2, 5], None, False, dtype="float32")
self.check_result_coo([2, 5], None, True)
self.check_result_coo([2, 5], 0, True, dtype="float32")
self.check_result_coo([2, 5], 0, False)
self.check_result_coo([2, 5], 1, False)
self.check_result_coo([2, 5], 0, False)
# 3d
for i in [0, 1, -2, None]:
self.check_result_coo([6, 2, 3], i, False)
self.check_result_coo([6, 2, 3], i, True)
# nd
for i in range(6):
self.check_result_coo([8, 3, 4, 4, 5, 3], i, False)
self.check_result_coo([8, 3, 4, 4, 5, 3], i, True)
# Randint currently only supports dimensions 0 to 9.
self.check_result_coo([2, 3, 4, 2, 3, 4, 2, 3, 4], i, False)
if __name__ == "__main__":
unittest.main()
......@@ -35,6 +35,7 @@ from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1
from .unary import transpose
from .unary import sum
from .unary import reshape
from .unary import isnan
......@@ -79,6 +80,7 @@ __all__ = [
'add',
'subtract',
'transpose',
'sum',
'multiply',
'divide',
'coalesce',
......
......@@ -15,12 +15,14 @@
import numpy as np
from paddle import _C_ops, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.framework import (
convert_np_dtype_to_dtype_,
core,
dygraph_only,
)
from paddle.fluid.layer_helper import LayerHelper
from paddle.framework import LayerHelper, in_dygraph_mode
__all__ = []
......@@ -155,6 +157,91 @@ def transpose(x, perm, name=None):
return _C_ops.sparse_transpose(x, perm)
def sum(x, axis=None, dtype=None, keepdim=False, name=None):
"""
Computes the sum of sparse tensor elements over the given dimension, requiring x to be a SparseCooTensor or SparseCsrTensor.
Args:
x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int32 or int64.
axis (int|list|tuple, optional): The dimensions along which the sum is performed. If
:attr:`None`, sum all elements of :attr:`x` and return a
Tensor with a single element, otherwise must be in the
range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`,
the dimension to reduce is :math:`rank + axis[i]`.
dtype (str, optional): The dtype of the output Tensor. The default value is None, in which case
the dtype of the output is the same as that of the input Tensor `x`.
keepdim (bool, optional): Whether to keep the reduced dimension in the
output Tensor. The result Tensor will have one fewer dimension
than :attr:`x` unless :attr:`keepdim` is true. The default value is False.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Results of the summation operation on the specified axis of the input Tensor `x`.
If `x.dtype` is `'bool'` or `'int32'`, the output data type is `'int64'`;
otherwise the output data type is the same as that of `x`.
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([[-2., 0.], [1., 2.]])
sparse_x = dense_x.to_sparse_coo(1)
out1 = paddle.sparse.sum(sparse_x) # [1.]
out2 = paddle.sparse.sum(sparse_x, axis=0) # [-1., 2.]
out3 = paddle.sparse.sum(sparse_x, axis=-1) # [-2., 3.]
out4 = paddle.sparse.sum(sparse_x, axis=1, keepdim=True) # [[-2.], [3.]]
"""
dtype_flag = False
if dtype is not None:
dtype_flag = True
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
return _C_ops.sparse_sum(x, axis, dtype, keepdim)
else:
if axis is None:
axis = []
else:
axis = [axis]
attrs = {'axis': axis, 'dtype': dtype, 'keepdim': keepdim}
if dtype_flag:
attrs.update({'in_dtype': x.dtype, 'out_dtype': dtype})
check_variable_and_dtype(
x,
'x',
[
'bool',
'float32',
'float64',
'int16',
'int32',
'int64',
],
'sparse_sum',
)
check_type(
axis, 'axis', (int, list, tuple, type(None), Variable), 'sparse_sum'
)
op_type = 'sparse_sum'
helper = LayerHelper(op_type)
if dtype_flag:
out = helper.create_sparse_variable_for_type_inference(dtype=dtype)
else:
out = helper.create_sparse_variable_for_type_inference(
dtype=x.dtype
)
helper.append_op(
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs
)
return out
@dygraph_only
def atan(x, name=None):
"""
......
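The static-graph branch above builds a `sparse_sum` op through `LayerHelper` and `append_op`. Below is a hedged sketch of exercising that path, modeled on the `TestSparseSumStatic` case in this change; the shapes, feed names, and values are illustrative only.

import numpy as np
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
    # Illustrative 2x5 COO tensor with three non-zero entries.
    indices = paddle.static.data(name='indices', shape=[2, 3], dtype='int64')
    values = paddle.static.data(name='values', shape=[3], dtype='float32')
    sp_x = paddle.sparse.sparse_coo_tensor(
        indices, values, shape=[2, 5], dtype='float32'
    )
    sp_out = paddle.sparse.sum(sp_x, axis=-1, keepdim=False)
    dense_out = sp_out.to_dense()
    exe = paddle.static.Executor()
    fetched = exe.run(
        feed={
            'indices': np.array([[0, 0, 1], [1, 3, 2]], dtype='int64'),
            'values': np.array([1.0, 2.0, 3.0], dtype='float32'),
        },
        fetch_list=[dense_out],
        return_numpy=True,
    )
    print(fetched[0])  # [3. 3.] -- row sums of the illustrative tensor
paddle.disable_static()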