diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index c541129f7ffbbb77ddb657dfa2d0a11353fb21d0..a18157ce8f7e318feb45d1c1470d4dc56db9790b 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -367,6 +367,17 @@ func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} +- backward_op : sum_grad + forward : sum(Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray axis={}, bool keepdim=false) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : sum_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, + sum_csr_grad {sparse_csr, sparse_csr -> sparse_csr} + - backward_op : sync_batch_norm_grad forward : sync_batch_norm_(Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics) diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 85b64e867eef8ea5970c71c77814c4bb70685ec4..41d4aedd66d1bf6f9f97aa640ace2cc693c83977 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -334,6 +334,17 @@ layout : x backward : subtract_grad +- op : sum + args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false) + output : Tensor(out) + infer_meta : + func : SumInferMeta + kernel : + func : sum_coo{sparse_coo -> sparse_coo}, + sum_csr{sparse_csr -> sparse_csr} + data_type : x + backward : sum_grad + - op : sync_batch_norm_ args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics) output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc index 8a5e3812950ece26090980e85d1649a93b98e71a..0771a3e9e72f41b5e934e7fc880bf875f965d11c 100644 --- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc @@ -51,6 +51,7 @@ PD_REGISTER_KERNEL(sum_grad, float, double, phi::dtype::float16, + int16_t, int, int64_t, phi::dtype::complex, diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu index 9ee6d530374dbb1b7ed37b3b201553ffa9bcb38a..3e88506f723d35d89f5d6231f51ed65c31ce7f4c 100644 --- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu @@ -67,6 +67,7 @@ PD_REGISTER_KERNEL(sum_grad, double, phi::dtype::float16, phi::dtype::bfloat16, + int16_t, int, int64_t, phi::dtype::complex, diff --git a/paddle/phi/kernels/sparse/cpu/sum_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sum_grad_kernel.cc new file mode 100644 index 
0000000000000000000000000000000000000000..05eb975b7e24e8bfe7c2ec611ca779d19cba9f08 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/sum_grad_kernel.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/reduce_sum_grad_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" + +namespace phi { +namespace sparse { + +template +void SumCooGradCPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCooTensor* dx) { + EmptyLikeCooKernel(dev_ctx, x, dx); + unsigned int n_dim = axis.size(); + + const DenseTensor& x_indices = x.indices(); + const DenseTensor& dout_indices = dout.indices(); + const DenseTensor& dout_values = dout.values(); + const auto* dout_indices_data = dout_indices.data(); + const auto* dout_values_data = dout_values.data(); + + DenseTensor* dx_indices = dx->mutable_indices(); + DenseTensor* dx_values = dx->mutable_values(); + *dx_indices = x_indices; + + const auto* dx_indices_data = dx_indices->data(); + auto* dx_values_data = dx_values->data(); + + phi::funcs::SetConstant set_constant; + if (n_dim == 0) { + T value = dout_values.data()[0]; + set_constant(dev_ctx, dx_values, value); + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } + return; + } + + auto dim = axis[0] < 0 ? x.dims().size() + axis[0] : axis[0]; + auto sparse_dim = x.sparse_dim(); + if (dim >= sparse_dim) { + dim = dim - sparse_dim + 1; + phi::ReduceSumGradKernel( + dev_ctx, x.values(), dout.values(), {dim}, keep_dim, false, dx_values); + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } + return; + } + // Ensure the sparse_dim is not less than 1. 
+ if (sparse_dim == 1) { + keep_dim = true; + } + + int64_t dense_dim = 1; + for (auto i = 1; i < x.values().dims().size(); ++i) { + dense_dim *= x.values().dims()[i]; + } + + std::map, int64_t> indices_map; + for (auto j = 0; j < dout_indices.dims()[1]; ++j) { + std::vector pos; + for (int i = 0; i < dout_indices.dims()[0]; ++i) { + pos.push_back(dout_indices_data[j + i * dout_indices.dims()[1]]); + } + indices_map[pos] = j; + } + + for (auto j = 0; j < dx_indices->dims()[1]; ++j) { + std::vector pos; + for (int i = 0; i < dx_indices->dims()[0]; ++i) { + if (i != dim) { + pos.push_back(dx_indices_data[j + i * dx_indices->dims()[1]]); + } else if (keep_dim) { + pos.push_back(0); + } + } + for (int i = 0; i < dense_dim; ++i) { + dx_values_data[i + j * dense_dim] = + dout_values_data[i + indices_map[pos] * dense_dim]; + } + } + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } +} + +template +void SumCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCsrTensor* dx) { + EmptyLikeCsrKernel(dev_ctx, x, dx); + unsigned int n_dim = axis.size(); + + const DenseTensor& x_crows = x.crows(); + const DenseTensor& x_cols = x.cols(); + const DenseTensor& dout_values = dout.values(); + const auto* x_crows_data = x_crows.data(); + + DenseTensor* dx_crows = dx->mutable_crows(); + DenseTensor* dx_cols = dx->mutable_cols(); + DenseTensor* dx_values = dx->mutable_values(); + + *dx_crows = x_crows; + *dx_cols = x_cols; + + phi::funcs::SetConstant set_constant; + if (n_dim == 0) { + T value = dout_values.data()[0]; + set_constant(dev_ctx, dx_values, value); + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } + return; + } + PADDLE_ENFORCE_EQ(axis[0], + -1, + phi::errors::Unimplemented( + "`axis` of SumCsrKernel only support None or -1 now." 
+ "More number will be supported in the future.")); + + if (x.dims().size() == 2) { + int value_index = 0; + for (int k = 0; k < x.dims()[0]; ++k) { + if (x_crows_data[k] == x_crows_data[k + 1]) { + continue; + } + T value = dout_values.data()[value_index]; + set_constant(dev_ctx, dx_values, value); + value_index += 1; + } + } else { + int dout_value_index = 0; + int dx_value_index = 0; + for (auto batch = 0; batch < x.dims()[0]; ++batch) { + for (auto k = batch * (x.dims()[1] + 1); + k < batch * (x.dims()[1] + 1) + x.dims()[1]; + ++k) { + if (x_crows_data[k] == x_crows_data[k + 1]) { + continue; + } + T value = dout_values.data()[dout_value_index]; + for (auto i = x_crows_data[k]; i < x_crows_data[k + 1]; ++i) { + dx_values->data()[dx_value_index] = value; + dx_value_index++; + } + dout_value_index++; + } + } + } + + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } +} + +template +void SumCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCooTensor* dx) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.indices().dtype(), "SumCooGradCPUKernel", ([&] { + SumCooGradCPUKernel( + dev_ctx, x, dout, axis, keep_dim, dx); + })); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sum_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::SumCooGradKernel, + float, + double, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(sum_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::SumCsrGradKernel, + float, + double, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/cpu/sum_kernel.cc b/paddle/phi/kernels/sparse/cpu/sum_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..2b4b11bea89e454e263f0ef9a09bd57a26b7c0de --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/sum_kernel.cc @@ -0,0 +1,283 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { +namespace sparse { + +template +void SumCooCPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCooTensor* out) { + size_t n_dim = axis.size(); + auto sparse_dim = x.sparse_dim(); + // create out sparse tensor + const auto& x_dims = x.dims(); + const auto& x_indices = x.indices(); + const auto& x_values = x.values(); + DDim out_dims; + DenseTensor out_indices; + DenseTensor out_values; + if (n_dim == 0) { + std::vector out_indices_shape; + if (keep_dim) { + out_dims = make_ddim(std::vector(x_dims.size(), 1)); + out_indices_shape = {sparse_dim, 1}; + } else { + out_dims = make_ddim({1}); + out_indices_shape = {1}; + } + out_indices = Empty(dev_ctx, out_indices_shape); + auto* out_indices_data = out_indices.data(); + std::fill(out_indices_data, out_indices_data + out_indices.numel(), 0); + out_values = phi::Sum(dev_ctx, x.values(), {}, dtype, keep_dim); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + return; + } + + auto dim = axis[0] < 0 ? x_dims.size() + axis[0] : axis[0]; + const auto* x_indices_data = x_indices.data(); + const auto* x_values_data = x_values.data(); + + std::vector dims; + for (int i = 0; i < x.dims().size(); ++i) { + if (i != dim) { + dims.emplace_back(x.dims()[i]); + } else if (keep_dim || (dim < sparse_dim && sparse_dim == 1)) { + dims.emplace_back(1); + } + } + out_dims = make_ddim(dims); + + if (dim >= sparse_dim) { + out_indices = x_indices; + dim = dim - sparse_dim + 1; + out_values = phi::Sum(dev_ctx, x.values(), {dim}, dtype, keep_dim); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + return; + } + + // Ensure the sparse_dim is not less than 1. + if (sparse_dim == 1) { + keep_dim = true; + } + // if axis in sparse_dim and keep_dim, sparse_dim will be reduced. + if (!keep_dim) { + sparse_dim -= 1; + } + + // indices_map is a mapping from output's position to values to be summed. 
+ std::map, std::vector> indices_map; + for (int64_t j = 0; j < x_indices.dims()[1]; ++j) { + std::vector pos; + for (int64_t i = 0; i < x_indices.dims()[0]; ++i) { + if (dim != i) { + pos.emplace_back(x_indices_data[j + i * x_indices.dims()[1]]); + } else if (keep_dim) { + pos.emplace_back(0); + } + } + indices_map[pos].emplace_back(j); + } + + std::vector out_values_dims; + out_values_dims.push_back(static_cast(indices_map.size())); + for (auto i = 1; i < x.values().dims().size(); ++i) { + out_values_dims.push_back(static_cast(x.values().dims()[i])); + } + int64_t dense_dim = std::accumulate(out_values_dims.begin() + 1, + out_values_dims.end(), + 1, + std::multiplies()); + + out_indices = Empty( + dev_ctx, {sparse_dim, static_cast(indices_map.size())}); + out_values = Empty(dev_ctx, out_values_dims); + + auto* out_indices_data = out_indices.data(); + auto* out_values_data = out_values.data(); + + auto iter_indices_map = indices_map.begin(); + for (size_t j = 0; j < indices_map.size(); ++j) { + std::vector pos = iter_indices_map->first; + std::vector values_index = iter_indices_map->second; + iter_indices_map++; + for (auto i = 0; i < sparse_dim; ++i) { + out_indices_data[j + i * indices_map.size()] = pos[i]; + } + for (auto i = 0; i < dense_dim; ++i) { + T out_value = 0; + for (auto index : values_index) { + out_value += x_values_data[i + index * dense_dim]; + } + out_values_data[i + j * dense_dim] = out_value; + } + } + + if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) { + out_values = phi::Cast(dev_ctx, out_values, dtype); + } + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); +} + +template +void SumCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCsrTensor* out) { + size_t n_dim = axis.size(); + const auto& x_crows = x.crows(); + const auto& x_values = x.values(); + const auto* x_crows_data = x_crows.data(); + const auto* x_values_data = x_values.data(); + + DenseTensor out_crows, out_cols, out_values; + DDim out_dims; + if (n_dim == 0) { + if (keep_dim && x.dims().size() == 3) { + out_dims = make_ddim({1, 1, 1}); + } else { + out_dims = make_ddim({1, 1}); + } + out_crows = Empty(dev_ctx, {2}); // crows = [0, 1] + auto* out_crows_data = out_crows.data(); + out_crows_data[0] = 0; + out_crows_data[1] = 1; + + out_cols = Empty(dev_ctx, {1}); // crows = [0] + auto* out_cols_data = out_cols.data(); + out_cols_data[0] = 0; + out_values = phi::Sum(dev_ctx, x.values(), {}, dtype, true); + } else { + PADDLE_ENFORCE_EQ(axis[0], + -1, + phi::errors::Unimplemented( + "`axis` of SumCsrKernel only support None or -1 now." 
+ "More number will be supported in the future.")); + out_crows = EmptyLike(dev_ctx, x.crows()); + auto* out_crows_data = out_crows.data(); + std::vector out_data; + if (x.dims().size() == 2) { + out_crows_data[0] = 0; + out_dims = make_ddim({x.dims()[0], 1}); + for (int i = 0; i < x.dims()[0]; ++i) { + if (x_crows_data[i] != x_crows_data[i + 1]) { + T sum_value = 0; + for (auto j = x_crows_data[i]; j < x_crows_data[i + 1]; ++j) { + sum_value += x_values_data[j]; + } + out_crows_data[i + 1] = out_crows_data[i] + 1; + out_data.emplace_back(sum_value); + } else { + out_crows_data[i + 1] = out_crows_data[i]; + } + } + } else { + if (keep_dim) { + out_dims = make_ddim({x.dims()[0], x.dims()[1], 1}); + } else { + out_dims = make_ddim({x.dims()[0], x.dims()[1]}); + } + int j = 0; + for (int batch = 0; batch < x.dims()[0]; ++batch) { + auto* cur_x_crows_data = x_crows_data + batch * x.dims()[2]; + auto* cur_out_crows_data = out_crows_data + batch * x.dims()[2]; + for (int i = 0; i < x.dims()[1]; ++i) { + cur_out_crows_data[0] = 0; + if (cur_x_crows_data[i] != cur_x_crows_data[i + 1]) { + T sum_value = 0; + for (auto k = cur_x_crows_data[i]; k < cur_x_crows_data[i + 1]; + ++k) { + sum_value += x_values_data[j++]; + } + out_data.emplace_back(sum_value); + cur_out_crows_data[i + 1] = cur_out_crows_data[i] + 1; + } else { + cur_out_crows_data[i + 1] = cur_out_crows_data[i]; + } + } + } + } + out_cols = + Empty(dev_ctx, {static_cast(out_data.size())}); + out_values = + Empty(dev_ctx, {static_cast(out_data.size())}); + auto* out_cols_data = out_cols.data(); + T* out_values_data = out_values.data(); + for (size_t i = 0; i < out_data.size(); ++i) { + out_cols_data[i] = 0; + out_values_data[i] = out_data[i]; + } + if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) { + out_values = phi::Cast(dev_ctx, out_values, dtype); + } + } + out->SetMember(out_crows, out_cols, out_values, out_dims); +} + +template +void SumCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCooTensor* out) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "SumCooCPUKernel", ([&] { + SumCooCPUKernel( + dev_ctx, x, axis, dtype, keep_dim, out); + })); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sum_coo, + CPU, + ALL_LAYOUT, + phi::sparse::SumCooKernel, + float, + double, + int16_t, + int, + int64_t, + bool) { + kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED); +} + +PD_REGISTER_KERNEL(sum_csr, + CPU, + ALL_LAYOUT, + phi::sparse::SumCsrKernel, + float, + double, + int16_t, + int, + int64_t, + bool) { + kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/sparse/gpu/sum_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sum_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..b0da1e7ab42f03191ae42609b9d6ca374b2f2c55 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/sum_grad_kernel.cu @@ -0,0 +1,235 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/reduce_sum_grad_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +namespace phi { +namespace sparse { + +template +__global__ void SetValueCudaKernel(const T* value, + const int64_t length, + T* data) { + CUDA_KERNEL_LOOP_TYPE(index, length, int64_t) { data[index] = value[0]; } +} + +template +__global__ void SumCsr2DGradCudaKernel(const int64_t* x_crows_data, + const T* dout_values_data, + const int64_t x_dim0, + T* dx_values_data) { + // dout_crows_data[index] should be equal to index; + CUDA_KERNEL_LOOP_TYPE(index, x_dim0, int64_t) { + T value = dout_values_data[index]; + for (auto i = x_crows_data[index]; i < x_crows_data[index + 1]; ++i) { + dx_values_data[i] = value; + } + } +} + +template +__global__ void SumCsr3DGradCudaKernel(const int64_t* x_crows_data, + const T* dout_values_data, + const int64_t x_dim0, + const int64_t x_dim1, + T* dx_values_data) { + // dout_crows_data[index] should be equal to number; + CUDA_KERNEL_LOOP_TYPE(index, x_dim0 * (x_dim1 + 1), int64_t) { + int64_t batch = index / (x_dim1 + 1); + int64_t number = index % (x_dim1 + 1); + + // compute offset of dx_values_data in every batch + int64_t batch_offset = 0; + for (int64_t b = 1; b <= batch; ++b) { + batch_offset += x_crows_data[b * (x_dim1 + 1) - 1]; + } + + T value = dout_values_data[index - batch]; + for (auto i = x_crows_data[index]; i < x_crows_data[index + 1]; ++i) { + dx_values_data[i + batch_offset] = value; + } + } +} + +template +void SumCooGradGPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCooTensor* dx) { + EmptyLikeCooKernel(dev_ctx, x, dx); + unsigned int n_dim = axis.size(); + + const DenseTensor& x_indices = x.indices(); + const DenseTensor& dout_indices = dout.indices(); + const DenseTensor& dout_values = dout.values(); + const auto* dout_indices_data = dout_indices.data(); + const auto* dout_values_data = dout_values.data(); + + DenseTensor* dx_indices = dx->mutable_indices(); + DenseTensor* dx_values = dx->mutable_values(); + *dx_indices = x_indices; + + const auto* dx_indices_data = dx_indices->data(); + auto* dx_values_data = dx_values->data(); + + if (n_dim == 0) { + auto length = dx->nnz(); + for (auto i = 1; i < x.values().dims().size(); ++i) { + length *= x.values().dims()[i]; + } + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, length, 1); + + SetValueCudaKernel + <<>>(dout_values_data, length, dx_values_data); + + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } + return; + } + + auto dim = axis[0] < 0 ? 
x.dims().size() + axis[0] : axis[0]; + auto sparse_dim = x.sparse_dim(); + if (dim >= sparse_dim) { + dim = dim - sparse_dim + 1; + phi::ReduceSumGradKernel( + dev_ctx, x.values(), dout.values(), {dim}, keep_dim, false, dx_values); + } else { + *dx_values = dout_values; + } + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } +} + +template +void SumCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCsrTensor* dx) { + EmptyLikeCsrKernel(dev_ctx, x, dx); + size_t n_dim = axis.size(); + + const DenseTensor& x_crows = x.crows(); + const DenseTensor& x_cols = x.cols(); + const DenseTensor& dout_values = dout.values(); + + DenseTensor* dx_crows = dx->mutable_crows(); + DenseTensor* dx_cols = dx->mutable_cols(); + DenseTensor* dx_values = dx->mutable_values(); + + const auto* x_crows_data = x_crows.data(); + const auto* dout_values_data = dout_values.data(); + auto* dx_values_data = dx_values->data(); + + *dx_crows = x_crows; + *dx_cols = x_cols; + + if (n_dim == 0) { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dx->nnz(), 1); + SetValueCudaKernel + <<>>(dout_values_data, dx->nnz(), dx_values_data); + + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } + return; + } + PADDLE_ENFORCE_EQ(axis[0], + -1, + phi::errors::Unimplemented( + "`axis` of SumCsrKernel only support None or -1 now." + "More number will be supported in the future.")); + if (x.dims().size() == 2) { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x.dims()[0], 1); + SumCsr2DGradCudaKernel<<>>( + x_crows_data, dout_values_data, x.dims()[0], dx_values_data); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, x.dims()[0] * (x.dims()[1] + 1), 1); + SumCsr3DGradCudaKernel<<>>(x_crows_data, + dout_values_data, + x.dims()[0], + x.dims()[1], + dx_values_data); + } + if (dx_values->dtype() != dx->dtype()) { + *dx_values = phi::Cast(dev_ctx, *dx_values, dx->dtype()); + } +} + +template +void SumCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCooTensor* dx) { + PD_VISIT_BASE_INTEGRAL_TYPES( + x.indices().dtype(), "SumCooGradGPUKernel", ([&] { + SumCooGradGPUKernel( + dev_ctx, x, dout, axis, keep_dim, dx); + })); +} +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sum_coo_grad, + GPU, + ALL_LAYOUT, + phi::sparse::SumCooGradKernel, + float, + double, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(sum_csr_grad, + GPU, + ALL_LAYOUT, + phi::sparse::SumCsrGradKernel, + float, + double, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/gpu/sum_kernel.cu b/paddle/phi/kernels/sparse/gpu/sum_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..594e1ec48b2e1f2efce92711bdaaa67d802a788a --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/sum_kernel.cu @@ -0,0 +1,456 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +namespace phi { +namespace sparse { + +template +__global__ void SumCooCudaKernel(const IntT* x_indices_data, + const T* x_values_data, + const int64_t x_nnz, + const int64_t dense_dim, + const int64_t sparse_dim, + const int64_t axis, + const bool keep_dim, + IntT* out_indices_data, + T* out_values_data) { + CUDA_KERNEL_LOOP_TYPE(index_i, x_nnz, int64_t) { + int64_t i = 0; + for (int j = 0; j < dense_dim; ++j) { + out_values_data[j + index_i * dense_dim] = 0; + } + + int64_t _index_j_ = + static_cast(blockIdx.y) * blockDim.y + threadIdx.y; + for (auto index_j = _index_j_; index_j < x_nnz; + index_j += static_cast(blockDim.y) * gridDim.y) { + // Determine whether the index_i and index_j elements have the same + // indices in all dimensions except for the specified axis dimension. 
+ bool same = true; + for (int j = 0; j < sparse_dim + !keep_dim; ++j) { + if (j != axis && x_indices_data[index_i + j * x_nnz] != + x_indices_data[index_j + j * x_nnz]) { + same = false; + break; + } + } + if (same) { + for (int j = 0; j < dense_dim; ++j) { + phi::CudaAtomicAdd(&out_values_data[j + index_i * dense_dim], + x_values_data[j + index_j * dense_dim]); + } + } + } + if (_index_j_ != 0) { + return; + } + if (keep_dim) { + for (int j = 0; j < sparse_dim; ++j) { + if (j == axis) { + out_indices_data[index_i + j * x_nnz] = 0; + } else { + out_indices_data[index_i + j * x_nnz] = + x_indices_data[index_i + j * x_nnz]; + } + } + return; + } + for (int j = 0; j < sparse_dim; ++j) { + // out_indices_data [sparse_dim, x.nnz()] + int64_t x_indices_data_offset; + if (j < axis) { + x_indices_data_offset = index_i + j * x_nnz; + } else { + x_indices_data_offset = index_i + (j + 1) * x_nnz; + } + out_indices_data[index_i + j * x_nnz] = + x_indices_data[x_indices_data_offset]; + } + } +} + +__global__ void SumAllCsrCudaKernel(int64_t* out_crows_data, + int64_t* out_cols_data) { + CUDA_KERNEL_LOOP_TYPE(index, 2, int64_t) { + out_crows_data[index] = index; + if (index == 0) { + out_cols_data[0] = 0; + } + } +} + +template +__global__ void SumCsr2DCudaKernel(const int64_t* x_crows_data, + const T* x_values_data, + const int64_t x_dim0, + int64_t* out_crows_data, + int64_t* out_cols_data, + T* out_values_data) { + CUDA_KERNEL_LOOP_TYPE(index, x_dim0 + 1, int64_t) { + out_crows_data[index] = index; + if (index != x_dim0) { + out_cols_data[index] = 0; + T sum_value = 0; + for (auto j = x_crows_data[index]; j < x_crows_data[index + 1]; ++j) { + sum_value += x_values_data[j]; + } + out_values_data[index] = sum_value; + } + } +} + +template +__global__ void SumCsr3DCudaKernel(const int64_t* x_crows_data, + const T* x_values_data, + const int64_t x_dim0, + const int64_t x_dim1, + const int64_t* batch_nnz_data, + int64_t* out_crows_data, + int64_t* out_cols_data, + T* out_values_data) { + CUDA_KERNEL_LOOP_TYPE(index, x_dim0 * (x_dim1 + 1), int64_t) { + int64_t batch = index / (x_dim1 + 1); + int64_t number = index % (x_dim1 + 1); + out_crows_data[index] = number; + out_cols_data[index] = 0; + + if (number != x_dim1) { + T sum_value = 0; + int64_t x_values_data_offset; + if (batch == 0) { + x_values_data_offset = 0; + } else { + x_values_data_offset = batch_nnz_data[batch - 1]; + } + for (int64_t j = x_crows_data[index]; j < x_crows_data[index + 1]; ++j) { + sum_value += x_values_data[j + x_values_data_offset]; + } + out_values_data[index - batch] = sum_value; + } + } +} + +template +void SumCooGPU0Kernel(const Context& dev_ctx, + const SparseCooTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCooTensor* out) { + auto sparse_dim = x.sparse_dim(); + // create out sparse tensor + const auto& x_dims = x.dims(); + const auto& x_indices = x.indices(); + const auto& x_values = x.values(); + DDim out_dims; + DenseTensor out_indices; + DenseTensor out_values; + if (keep_dim) { + out_dims = make_ddim(std::vector(x_dims.size(), 1)); + out_indices = Empty(dev_ctx, {sparse_dim, 1}); + } else { + out_dims = make_ddim({1}); + out_indices = Empty(dev_ctx, {1, 1}); + } + phi::funcs::SetConstant set_out_indices; + set_out_indices(dev_ctx, &out_indices, static_cast(0)); + out_values = phi::Sum(dev_ctx, x.values(), {}, dtype, keep_dim); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); +} + +template +void SumCooGPU1Kernel(const Context& dev_ctx, + const SparseCooTensor& x, + 
const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCooTensor* out) { + auto sparse_dim = x.sparse_dim(); + // create out sparse tensor + const auto& x_dims = x.dims(); + const auto& x_indices = x.indices(); + const auto& x_values = x.values(); + DDim out_dims; + DenseTensor out_indices; + DenseTensor out_values; + auto n_dim = x.dims().size(); + auto dim = axis[0] < 0 ? x_dims.size() + axis[0] : axis[0]; + + std::vector dims; + for (int i = 0; i < n_dim; ++i) { + if (i != dim) { + dims.emplace_back(x.dims()[i]); + } else if (keep_dim || (dim < sparse_dim && sparse_dim == 1)) { + dims.emplace_back(1); + } + } + out_dims = make_ddim(dims); + + if (dim >= sparse_dim) { + out_indices = x_indices; + dim = dim - sparse_dim + 1; + out_values = phi::Sum(dev_ctx, x.values(), {dim}, dtype, keep_dim); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + return; + } + + // Ensure the sparse_dim is not less than 1. + if (sparse_dim == 1) { + keep_dim = true; + } + // if axis in sparse_dim and keep_dim, sparse_dim will be reduced. + if (!keep_dim) { + sparse_dim -= 1; + } + + std::vector out_values_dims; + out_values_dims.push_back(x.nnz()); + for (auto i = 1; i < x.values().dims().size(); ++i) { + out_values_dims.push_back(static_cast(x.values().dims()[i])); + } + int64_t dense_dim = std::accumulate(out_values_dims.begin() + 1, + out_values_dims.end(), + 1, + std::multiplies()); + + out_indices = Empty(dev_ctx, {sparse_dim, x.nnz()}); + out_values = Empty(dev_ctx, out_values_dims); + + const auto* x_indices_data = x_indices.data(); + const auto* x_values_data = x_values.data(); + auto* out_indices_data = out_indices.data(); + auto* out_values_data = out_values.data(); + + auto config = + phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, x.nnz(), x.nnz()); + SumCooCudaKernel<<>>(x_indices_data, + x_values_data, + x.nnz(), + dense_dim, + sparse_dim, + dim, + keep_dim, + out_indices_data, + out_values_data); + if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) { + out_values = phi::Cast(dev_ctx, out_values, dtype); + } + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); +} + +template +void SumCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCooTensor* out) { + const size_t n_dim = axis.size(); + if (n_dim == 0) { + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "SumCooGPUKernel", ([&] { + SumCooGPU0Kernel( + dev_ctx, x, axis, dtype, keep_dim, out); + })); + } else { + PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "SumCooGPUKernel", ([&] { + SumCooGPU1Kernel( + dev_ctx, x, axis, dtype, keep_dim, out); + })); + } +} + +template +void SumCsr0Kernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCsrTensor* out) { + auto x_dim0 = x.dims()[0]; + auto x_dim1 = x.dims()[1]; + const auto& x_crows = x.crows(); + const auto& x_values = x.values(); + const auto* x_crows_data = x_crows.data(); + const auto* x_values_data = x_values.data(); + + DenseTensor out_crows, out_cols, out_values; + DDim out_dims; + if (keep_dim && x.dims().size() == 3) { + out_dims = make_ddim({1, 1, 1}); + } else { + out_dims = make_ddim({1, 1}); + } + out_crows = Empty(dev_ctx, {2}); // crows = [0, 1] + out_cols = Empty(dev_ctx, {1}); // crows = [0] + auto* out_crows_data = out_crows.data(); + auto* out_cols_data = out_cols.data(); + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, 2, 1); + 
SumAllCsrCudaKernel<<>>(out_crows_data, out_cols_data); + + out_values = phi::Sum(dev_ctx, x.values(), {}, dtype, true); + out->SetMember(out_crows, out_cols, out_values, out_dims); +} + +template +void SumCsr1Kernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCsrTensor* out) { + auto x_dim0 = x.dims()[0]; + auto x_dim1 = x.dims()[1]; + const auto& x_crows = x.crows(); + const auto& x_values = x.values(); + const auto* x_crows_data = x_crows.data(); + const auto* x_values_data = x_values.data(); + + DenseTensor out_crows, out_cols, out_values; + DDim out_dims; + out_crows = EmptyLike(dev_ctx, x.crows()); + auto* out_crows_data = out_crows.data(); + + if (x.dims().size() == 2) { + out_cols = Empty(dev_ctx, {x_dim0}); + out_values = Empty(dev_ctx, {x_dim0}); + auto* out_cols_data = out_cols.data(); + auto* out_values_data = out_values.data(); + out_dims = make_ddim({x_dim0, 1}); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_dim0 + 1, 1); + SumCsr2DCudaKernel<<>>(x_crows_data, + x_values_data, + x_dim0, + out_crows_data, + out_cols_data, + out_values_data); + + } else { + out_cols = Empty(dev_ctx, {x_dim0 * x_dim1}); + out_values = Empty(dev_ctx, {x_dim0 * x_dim1}); + auto* out_cols_data = out_cols.data(); + auto* out_values_data = out_values.data(); + if (keep_dim) { + out_dims = make_ddim({x_dim0, x_dim1, 1}); + } else { + out_dims = make_ddim({x_dim0, x_dim1}); + } + + DenseTensor x_crows_reshape = + Reshape(dev_ctx, x_crows, {x_dim0, x_dim1 + 1}); + DenseTensor last_indices = Empty(dev_ctx, {1}); + phi::funcs::SetConstant set_constant; + set_constant(dev_ctx, &last_indices, x_dim1); + + DenseTensor x_crows_last = Empty(dev_ctx, {x_dim0, 1}); + IndexSelectKernel( + dev_ctx, x_crows_reshape, last_indices, 1, &x_crows_last); + + DenseTensor batch_nnz = Empty(dev_ctx, {x_dim0, 1}); + CumsumKernel( + dev_ctx, x_crows_last, Scalar(0), false, false, false, &batch_nnz); + auto* batch_nnz_data = batch_nnz.data(); + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, x.dims()[0] * (x.dims()[1] + 1), 1); + SumCsr3DCudaKernel<<>>(x_crows_data, + x_values_data, + x_dim0, + x_dim1, + batch_nnz_data, + out_crows_data, + out_cols_data, + out_values_data); + } + if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) { + out_values = phi::Cast(dev_ctx, out_values, dtype); + } + out->SetMember(out_crows, out_cols, out_values, out_dims); +} + +template +void SumCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCsrTensor* out) { + size_t n_dim = axis.size(); + if (n_dim == 0) { + SumCsr0Kernel(dev_ctx, x, axis, dtype, keep_dim, out); + } else { + PADDLE_ENFORCE_EQ(axis[0], + -1, + phi::errors::Unimplemented( + "`axis` of SumCsrKernel only support None or -1 now." 
+ "More number will be supported in the future.")); + SumCsr1Kernel(dev_ctx, x, axis, dtype, keep_dim, out); + } +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(sum_coo, + GPU, + ALL_LAYOUT, + phi::sparse::SumCooKernel, + float, + double, + int, + int64_t) { + kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED); +} + +PD_REGISTER_KERNEL(sum_csr, + GPU, + ALL_LAYOUT, + phi::sparse::SumCsrKernel, + float, + double, + int, + int64_t) { + kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.h b/paddle/phi/kernels/sparse/unary_grad_kernel.h index 7440533057022ee5d2f0cfa69db663bf32cba4da..5893b16f6ba3d827de598d4d4f1c877ec3064cbe 100644 --- a/paddle/phi/kernels/sparse/unary_grad_kernel.h +++ b/paddle/phi/kernels/sparse/unary_grad_kernel.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -92,6 +93,22 @@ void TransposeCsrGradKernel(const Context& dev_ctx, const std::vector& perm, SparseCsrTensor* dx); +template +void SumCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCooTensor* dx); + +template +void SumCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& dout, + const IntArray& axis, + bool keep_dim, + SparseCsrTensor* dx); + template void ReshapeCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h index b219ec07236df8d1a7d686328d4f9475cbc244af..d692f75b59408289050055735caf73a878f70361 100644 --- a/paddle/phi/kernels/sparse/unary_kernel.h +++ b/paddle/phi/kernels/sparse/unary_kernel.h @@ -157,6 +157,22 @@ SparseCsrTensor TransposeCsr(const Context& dev_ctx, return csr; } +template +void SumCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCooTensor* out); + +template +void SumCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const IntArray& axis, + DataType dtype, + bool keep_dim, + SparseCsrTensor* out); + template SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) { SparseCooTensor coo; diff --git a/python/paddle/fluid/tests/unittests/test_sparse_sum_op.py b/python/paddle/fluid/tests/unittests/test_sparse_sum_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3690341c51dc0d051898795b2a4ede10916b5133 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_sum_op.py @@ -0,0 +1,204 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle + +devices = ['cpu'] +if paddle.device.get_device() != "cpu": + devices.append(paddle.device.get_device()) + + +class TestSparseSum(unittest.TestCase): + """ + Test the API paddle.sparse.sum on some sparse tensors. + x: sparse tensor, out: sparse tensor + """ + + def to_sparse(self, x, format, sparse_dim=None): + if format == 'coo': + if sparse_dim: + return x.detach().to_sparse_coo(sparse_dim=sparse_dim) + else: + return x.detach().to_sparse_coo(sparse_dim=x.ndim) + elif format == 'csr': + return x.detach().to_sparse_csr() + + def check_result( + self, x_shape, dims, keepdim, format, sparse_dim=None, dtype=None + ): + for device in devices: + paddle.device.set_device(device) + if sparse_dim: + mask_shape = [*x_shape[:sparse_dim]] + [1] * ( + len(x_shape) - sparse_dim + ) + mask = paddle.randint(0, 2, mask_shape) + else: + mask = paddle.randint(0, 2, x_shape) + + while paddle.sum(mask) == 0: + if sparse_dim: + mask_shape = [*x_shape[:sparse_dim]] + [1] * ( + len(x_shape) - sparse_dim + ) + mask = paddle.randint(0, 2, mask_shape) + else: + mask = paddle.randint(0, 2, x_shape) + # "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask", + # or the backward checks may fail. + origin_x = (paddle.rand(x_shape, dtype='float64') + 1) * mask + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_out = paddle.sum(dense_x, dims, keepdim=keepdim, dtype=dtype) + sp_x = self.to_sparse(origin_x, format, sparse_dim) + sp_x.stop_gradient = False + sp_out = paddle.sparse.sum(sp_x, dims, keepdim=keepdim, dtype=dtype) + np.testing.assert_allclose( + sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05 + ) + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose( + sp_x.grad.to_dense().numpy(), + (dense_x.grad * mask).numpy(), + rtol=1e-05, + ) + + def test_sum_1d(self): + self.check_result([5], None, False, 'coo') + self.check_result([5], None, True, 'coo') + self.check_result([5], 0, False, 'coo') + self.check_result([5], 0, True, 'coo') + + def test_sum_2d(self): + self.check_result([2, 5], None, False, 'coo', dtype="float32") + self.check_result([2, 5], None, True, 'coo') + self.check_result([2, 5], 0, True, 'coo', dtype="float32") + self.check_result([2, 5], 0, False, 'coo') + self.check_result([2, 5], 1, False, 'coo') + self.check_result([2, 5], None, True, 'csr', dtype="float32") + self.check_result([2, 5], -1, True, 'csr', dtype="float32") + self.check_result([2, 5], 0, False, 'coo') + self.check_result([2, 5], -1, True, 'csr') + + def test_sum_3d(self): + self.check_result([6, 2, 3], -1, True, 'csr') + for i in [0, 1, -2, None]: + self.check_result([6, 2, 3], i, False, 'coo') + self.check_result([6, 2, 3], i, True, 'coo') + + def test_sum_nd(self): + for i in range(6): + self.check_result([8, 3, 4, 4, 5, 3], i, False, 'coo') + self.check_result([8, 3, 4, 4, 5, 3], i, True, 'coo') + # Randint now only supports access to dimension 0 to 9. 
+ self.check_result([2, 3, 4, 2, 3, 4, 2, 3, 4], i, False, 'coo') + + def test_sum_sparse_dim(self): + for i in range(6): + self.check_result([8, 3, 4, 4, 5, 3], i, False, 'coo', sparse_dim=3) + self.check_result([8, 3, 4, 4, 5, 3], i, True, 'coo', sparse_dim=3) + + +class TestSparseSumStatic(unittest.TestCase): + def check_result_coo(self, x_shape, dims, keepdim, dtype=None): + for device in devices: + paddle.device.set_device(device) + mask = paddle.randint(0, 2, x_shape) + while paddle.sum(mask) == 0: + mask = paddle.randint(0, 2, x_shape) + origin_data = (paddle.rand(x_shape, dtype='float32') + 1) * mask + sparse_data = origin_data.detach().to_sparse_coo( + sparse_dim=len(x_shape) + ) + indices_data = sparse_data.indices() + values_data = sparse_data.values() + + dense_x = origin_data + dense_out = paddle.sum(dense_x, dims, keepdim=keepdim, dtype=dtype) + + paddle.enable_static() + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + indices = paddle.static.data( + name='indices', + shape=indices_data.shape, + dtype=indices_data.dtype, + ) + values = paddle.static.data( + name='values', + shape=values_data.shape, + dtype=values_data.dtype, + ) + sp_x = paddle.sparse.sparse_coo_tensor( + indices, + values, + shape=origin_data.shape, + dtype=origin_data.dtype, + ) + sp_out = paddle.sparse.sum( + sp_x, dims, keepdim=keepdim, dtype=dtype + ) + sp_dense_out = sp_out.to_dense() + + sparse_exe = paddle.static.Executor() + sparse_fetch = sparse_exe.run( + feed={ + 'indices': indices_data.numpy(), + "values": values_data.numpy(), + }, + fetch_list=[sp_dense_out], + return_numpy=True, + ) + + np.testing.assert_allclose( + dense_out.numpy(), sparse_fetch[0], rtol=1e-5 + ) + paddle.disable_static() + + def test_sum(self): + # 1d + self.check_result_coo([5], None, False) + self.check_result_coo([5], None, True) + self.check_result_coo([5], 0, True) + self.check_result_coo([5], 0, False) + + # 2d + self.check_result_coo([2, 5], None, False, dtype="float32") + self.check_result_coo([2, 5], None, True) + self.check_result_coo([2, 5], 0, True, dtype="float32") + self.check_result_coo([2, 5], 0, False) + self.check_result_coo([2, 5], 1, False) + self.check_result_coo([2, 5], 0, False) + + # 3d + for i in [0, 1, -2, None]: + self.check_result_coo([6, 2, 3], i, False) + self.check_result_coo([6, 2, 3], i, True) + + # nd + for i in range(6): + self.check_result_coo([8, 3, 4, 4, 5, 3], i, False) + self.check_result_coo([8, 3, 4, 4, 5, 3], i, True) + # Randint now only supports access to dimension 0 to 9. 
+ self.check_result_coo([2, 3, 4, 2, 3, 4, 2, 3, 4], i, False) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/sparse/__init__.py b/python/paddle/sparse/__init__.py index e92a5936c4cfc9bd8616a44db820d9ec44df93ed..99051f7cc67021762f00089d34d7b20ba990ebf3 100644 --- a/python/paddle/sparse/__init__.py +++ b/python/paddle/sparse/__init__.py @@ -35,6 +35,7 @@ from .unary import deg2rad from .unary import rad2deg from .unary import expm1 from .unary import transpose +from .unary import sum from .unary import reshape from .unary import isnan @@ -79,6 +80,7 @@ __all__ = [ 'add', 'subtract', 'transpose', + 'sum', 'multiply', 'divide', 'coalesce', diff --git a/python/paddle/sparse/unary.py b/python/paddle/sparse/unary.py index da1d0b549aa413cd9a85b6d158add50f06d569d9..453980225891a0a2a6b8775f751d3bc995fb2afc 100644 --- a/python/paddle/sparse/unary.py +++ b/python/paddle/sparse/unary.py @@ -15,12 +15,14 @@ import numpy as np from paddle import _C_ops, in_dynamic_mode +from paddle.common_ops_import import Variable +from paddle.fluid.data_feeder import check_type, check_variable_and_dtype from paddle.fluid.framework import ( convert_np_dtype_to_dtype_, core, dygraph_only, ) -from paddle.fluid.layer_helper import LayerHelper +from paddle.framework import LayerHelper, in_dygraph_mode __all__ = [] @@ -155,6 +157,91 @@ def transpose(x, perm, name=None): return _C_ops.sparse_transpose(x, perm) +def sum(x, axis=None, dtype=None, keepdim=False, name=None): + """ + Computes the sum of sparse tensor elements over the given dimension, requiring x to be a SparseCooTensor or SparseCsrTensor. + + Args: + x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int32 or int64. + axis (int|list|tuple, optional): The dimensions along which the sum is performed. If + :attr:`None`, sum all elements of :attr:`x` and return a + Tensor with a single element, otherwise must be in the + range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`, + the dimension to reduce is :math:`rank + axis[i]`. + dtype (str, optional): The dtype of output Tensor. The default value is None, the dtype + of output is the same as input Tensor `x`. + keepdim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result Tensor will have one fewer dimension + than the :attr:`x` unless :attr:`keepdim` is true, default + value is False. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Results of summation operation on the specified axis of input Tensor `x`. + if `x.dtype='bool'` or `x.dtype='int32'`, it's data type is `'int64'`, + otherwise it's data type is the same as `x`. + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([[-2., 0.], [1., 2.]]) + sparse_x = dense_x.to_sparse_coo(1) + out1 = paddle.sparse.sum(sparse_x) # [1.] + out2 = paddle.sparse.sum(sparse_x, axis=0) # [-1., 2.] + out3 = paddle.sparse.sum(sparse_x, axis=-1) # [-2., 3.] 
+ out4 = paddle.sparse.sum(sparse_x, axis=1, keepdim=True) # [[-2.], [3.]] + """ + dtype_flag = False + if dtype is not None: + dtype_flag = True + dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dygraph_mode(): + return _C_ops.sparse_sum(x, axis, dtype, keepdim) + else: + if axis is None: + axis = [] + else: + axis = [axis] + attrs = {'axis': axis, 'dtype': dtype, 'keepdim': keepdim} + + if dtype_flag: + attrs.update({'in_dtype': x.dtype, 'out_dtype': dtype}) + + check_variable_and_dtype( + x, + 'x', + [ + 'bool', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + ], + 'sparse_sum', + ) + + check_type( + axis, 'axis', (int, list, tuple, type(None), Variable), 'sparse_sum' + ) + + op_type = 'sparse_sum' + helper = LayerHelper(op_type) + if dtype_flag: + out = helper.create_sparse_variable_for_type_inference(dtype=dtype) + else: + out = helper.create_sparse_variable_for_type_inference( + dtype=x.dtype + ) + helper.append_op( + type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs + ) + return out + + @dygraph_only def atan(x, name=None): """
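For reference, below is a minimal usage sketch of the `paddle.sparse.sum` API introduced by this patch, adapted from the docstring and unit tests above; it assumes the patch is applied and runs in dynamic graph mode. The expected values in the comments are the ones stated in the patch's own docstring. Note that the CSR path currently accepts only `axis=None` or `axis=-1`, as enforced by the `PADDLE_ENFORCE_EQ` check in `SumCsrKernel`.

import paddle

dense_x = paddle.to_tensor([[-2., 0.], [1., 2.]])

# COO: any axis in [-rank(x), rank(x)) is supported by the new kernels.
coo_x = dense_x.to_sparse_coo(sparse_dim=1)
out_all = paddle.sparse.sum(coo_x)                         # [1.]
out_axis0 = paddle.sparse.sum(coo_x, axis=0)               # [-1., 2.]
out_keep = paddle.sparse.sum(coo_x, axis=1, keepdim=True)  # [[-2.], [3.]]

# CSR: only axis=None or axis=-1 is implemented in this patch.
csr_x = dense_x.to_sparse_csr()
out_csr_all = paddle.sparse.sum(csr_x)           # sum over all elements
out_csr_row = paddle.sparse.sum(csr_x, axis=-1)  # per-row sums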