diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h index b719ceddc55631290b5c31e65c456f45d5ea76a8..8a1eacd37098882afcd91cea25874d9a74b6f906 100644 --- a/paddle/phi/kernels/activation_kernel.h +++ b/paddle/phi/kernels/activation_kernel.h @@ -71,6 +71,7 @@ DECLARE_ACTIVATION_KERNEL(Log1p) DECLARE_ACTIVATION_KERNEL(Round) DECLARE_ACTIVATION_KERNEL(Floor) DECLARE_ACTIVATION_KERNEL(Ceil) +DECLARE_ACTIVATION_KERNEL(Negative) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index bd3e16d54dcad460c066fbe6ea5213f09712039f..aca97340df4e3872663aef866fb301aac3d55c05 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -89,6 +89,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Round, RoundFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor) +DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, @@ -182,6 +183,15 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) +PD_REGISTER_KERNEL(negative, + CPU, + ALL_LAYOUT, + phi::NegativeKernel, + float, + double, + int16_t, + int, + int64_t) {} PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel) PD_REGISTER_KERNEL( pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index f481821a7bfcc1a3ba1bb6f15d13ccf618dfdd54..9cfd9c108368590271bf973cddca8086b6c3999b 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1814,6 +1814,14 @@ struct CeilFunctor : public BaseActivationFunctor { } }; +template +struct NegativeFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = -x; + } +}; + template struct ZeroGradFunctor : public BaseActivationFunctor { template +void AllocCsrPtr(const Context& dev_ctx, + const SparseCsrTensor& x, + SparseCsrTensor* dx) { + DenseTensor dx_crows = phi::EmptyLike(dev_ctx, x.non_zero_crows()); + DenseTensor dx_cols = phi::EmptyLike(dev_ctx, x.non_zero_cols()); + DenseTensor dx_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + dx->SetMember(dx_crows, dx_cols, dx_values, x.dims()); +} + +template +void AllocCooPtr(const Context& dev_ctx, + const SparseCooTensor& x, + SparseCooTensor* dx) { + DenseTensor dx_indices = phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor dx_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + dx->SetMember(dx_indices, dx_values, x.dims(), true); +} + +template +void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy) { + // Special case when y_grad is not needed + if (dx != nullptr && dy == nullptr) { + VLOG(4) << "Special case when dy is not needed"; + AllocCsrPtr(dev_ctx, x, dx); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + } else if (dx == nullptr && dy != nullptr) { 
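The subtract and divide gradients defined below reuse dout's sparsity pattern: they copy dout into dy and then run phi::NegativeKernel over dy's non_zero_elements only. That is valid because negation maps zero to zero, so only the stored values change while the crows/cols (or indices) buffers allocated by AllocCsrPtr/AllocCooPtr stay identical to dout's. A minimal NumPy sketch of that observation (illustrative only, not part of the patch):

import numpy as np

# A toy COO layout: a 2 x nnz index matrix plus the nnz stored values.
indices = np.array([[0, 1, 1],
                    [2, 0, 3]])
values = np.array([3.0, -1.5, 2.0])

# Negating a sparse tensor only needs to touch the stored values; the
# implicit zeros stay zero, so the index structure can be shared as-is.
neg_values = -values
print(indices)       # unchanged sparsity pattern
print(neg_values)    # [-3.   1.5 -2. ]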
+ VLOG(4) << "Special case when dx is not needed"; + AllocCsrPtr(dev_ctx, y, dy); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + } else { + AllocCsrPtr(dev_ctx, x, dx); + AllocCsrPtr(dev_ctx, y, dy); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + } +} + +template +void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy) { + if (dx) { + AllocCsrPtr(dev_ctx, x, dx); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + } + + if (dy) { + AllocCsrPtr(dev_ctx, y, dy); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + phi::NegativeKernel( + dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); + } +} + +template +void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy) { + if (dx) { + // dout*y + AllocCsrPtr(dev_ctx, x, dx); + sparse::ElementWiseMultiplyCsrKernel(dev_ctx, dout, y, dx); + } + + if (dy) { + // dout*x + AllocCsrPtr(dev_ctx, y, dy); + sparse::ElementWiseMultiplyCsrKernel(dev_ctx, dout, x, dy); + } +} + +template +void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& out, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy) { + if (dx) { + // dout/y + AllocCsrPtr(dev_ctx, x, dx); + sparse::ElementWiseDivideCsrKernel(dev_ctx, dout, y, dx); + } + + if (dy) { + // -dout * out / y + AllocCsrPtr(dev_ctx, y, dy); + CopyCsr(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + phi::NegativeKernel( + dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); + auto tmp = sparse::ElementWiseMultiplyCsr(dev_ctx, *dy, out); + sparse::ElementWiseDivideCsrKernel(dev_ctx, tmp, y, dy); + } +} + +template +void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy) { + // Special case when y_grad is not needed*/ + if (dx != nullptr && dy == nullptr) { + VLOG(4) << "Special case when dy is not needed"; + AllocCooPtr(dev_ctx, x, dx); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + } else if (dx == nullptr && dy != nullptr) { + VLOG(4) << "Special case when dx is not needed"; + AllocCooPtr(dev_ctx, y, dy); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + } else { + AllocCooPtr(dev_ctx, x, dx); + AllocCooPtr(dev_ctx, y, dy); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + } +} + +template +void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy) { + if (dx) { + AllocCooPtr(dev_ctx, x, dx); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); + } + + if (dy) { + AllocCooPtr(dev_ctx, y, dy); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + phi::NegativeKernel( + dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); + } +} + +template +void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& dout, + SparseCooTensor* dx, + 
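For reference, the grad kernels above implement the usual dense element-wise rules restricted to the stored entries: add gives dx = dy = dout; subtract gives dx = dout and dy = -dout; multiply gives dx = dout * y and dy = dout * x; divide gives dx = dout / y and dy = -dout * out / y with out = x / y. A small NumPy spot check of the divide rule against a finite difference (a sketch of the math only, not the sparse implementation; all names are local to this example):

import numpy as np

x = np.array([1.0, 4.0, 9.0])
y = np.array([2.0, 4.0, 3.0])
dout = np.ones_like(x)          # upstream gradient

out = x / y
dx = dout / y                   # d(x/y)/dx = 1/y
dy = -dout * out / y            # d(x/y)/dy = -x/y**2 = -out/y

eps = 1e-6
dy_numeric = (x / (y + eps) - x / y) / eps
print(np.allclose(dy, dy_numeric, atol=1e-4))   # True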
SparseCooTensor* dy) { + if (dx) { + // dout*y + AllocCooPtr(dev_ctx, x, dx); + sparse::ElementWiseMultiplyCooKernel(dev_ctx, dout, y, dx); + } + + if (dy) { + // dout*x + AllocCooPtr(dev_ctx, y, dy); + sparse::ElementWiseMultiplyCooKernel(dev_ctx, dout, x, dy); + } +} + +template +void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& out, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy) { + if (dx) { + // dout/y + AllocCooPtr(dev_ctx, x, dx); + sparse::ElementWiseDivideCooKernel(dev_ctx, dout, y, dx); + } + + if (dy) { + // -dout * out / y + AllocCooPtr(dev_ctx, y, dy); + CopyCoo(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); + phi::NegativeKernel( + dev_ctx, dout.non_zero_elements(), dy->mutable_non_zero_elements()); + auto tmp = sparse::ElementWiseMultiplyCoo(dev_ctx, *dy, out); + sparse::ElementWiseDivideCooKernel(dev_ctx, tmp, y, dy); + } +} +// CPU Kernel end + +// Kernel +template +void ElementWiseDivideCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& out, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_crows().dtype(), "ElementWiseDivideCsrGradCPUKernel", ([&] { + ElementWiseDivideCsrGradCPUKernel( + dev_ctx, x, y, out, dout, dx, dy); + })); +} +template +void ElementWiseDivideCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& out, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "ElementWiseDivideCooGradCPUKernel", ([&] { + ElementWiseDivideCooGradCPUKernel( + dev_ctx, x, y, out, dout, dx, dy); + })); +} + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL(name) \ + DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \ + \ + DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \ + template \ + void ElementWise##name##CsrGradKernel(const Context& dev_ctx, \ + const SparseCsrTensor& x, \ + const SparseCsrTensor& y, \ + const SparseCsrTensor& dout, \ + SparseCsrTensor* dx, \ + SparseCsrTensor* dy) { \ + PD_VISIT_INTEGRAL_TYPES(x.non_zero_crows().dtype(), \ + "ElementWise##name##CsrGradCPUKernel", \ + ([&] { \ + ElementWise##name##CsrGradCPUKernel( \ + dev_ctx, x, y, dout, dx, dy); \ + })); \ + } + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) \ + template \ + void ElementWise##name##CooGradKernel(const Context& dev_ctx, \ + const SparseCooTensor& x, \ + const SparseCooTensor& y, \ + const SparseCooTensor& dout, \ + SparseCooTensor* dx, \ + SparseCooTensor* dy) { \ + PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \ + "ElementWise##name##CooGradCPUKernel", \ + ([&] { \ + ElementWise##name##CooGradCPUKernel( \ + dev_ctx, x, y, dout, dx, dy); \ + })); \ + } + +DEFINE_ELEMENTWISE_GRAD_KERNEL(Add) +DEFINE_ELEMENTWISE_GRAD_KERNEL(Subtract) +DEFINE_ELEMENTWISE_GRAD_KERNEL(Multiply) + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(add_csr_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddCsrGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(subtract_csr_csr_grad, + CPU, + ALL_LAYOUT, + 
phi::sparse::ElementWiseSubtractCsrGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(multiply_csr_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseMultiplyCsrGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(divide_csr_csr_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseDivideCsrGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(add_coo_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddCooGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(subtract_coo_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseSubtractCooGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(multiply_coo_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseMultiplyCooGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(divide_coo_coo_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseDivideCooGradKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc8592cbc9d4de51d64baaf41e833e166d60fd1d --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc @@ -0,0 +1,451 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/sparse/elementwise_kernel.h" + +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/elementwise_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +namespace phi { +namespace sparse { + +template +struct BinaryOPWithZeroCompareFunctor { + explicit BinaryOPWithZeroCompareFunctor(Functor functor) + : functor_(functor) {} + inline HOSTDEVICE bool operator()(const T* a, + const T* b, + T* result, + const int64_t len) const { + bool is_zero = true; + for (int64_t i = 0; i < len; ++i) { + result[i] = functor_(a[i], b[i]); + if (result[i] != 0) { + is_zero = false; + } + } + return is_zero; + } + Functor functor_; +}; + +template +void Merge(const IntT el_len, + const IntT* a_index, + const T* a_values, + const IntT len_a, + const IntT* b_index_org, + const T* b_values_org, + const IntT len_b, + const IntT len_b_max, + IntT* c_index, + T* c_values, + IntT& nnz, + const Functor& functor_org, + const bool is_divide) { + IntT a = 0; + IntT b = 0; + nnz = 0; + const IntT* b_index = nullptr; + std::vector b_full_index; + const std::vector zero(el_len, 0); + auto functor = BinaryOPWithZeroCompareFunctor(functor_org); + + std::vector b_values(len_b_max, zero.data()); + for (auto i = 0; i < len_b; ++i) { + b_values[b_index_org[i]] = b_values_org + i * el_len; + } + // if is divide expend b_index_org to b_full_index + if (is_divide) { + b_full_index = std::vector(len_b_max); + for (int64_t j = 0; j < static_cast(b_full_index.size()); ++j) { + b_full_index[j] = j; + } + b_index = b_full_index.data(); + } else { + b_index = b_index_org; + } + // merge + while (a < len_a && b < (is_divide ? len_b_max : len_b)) { + if (a_index[a] == b_index[b]) { + if (!functor(a_values + a * el_len, + b_values[b_index[b]], + c_values + nnz * el_len, + el_len)) { + c_index[nnz] = a_index[a]; + ++nnz; + } + ++a; + ++b; + } + // coordinate x[a] < coordinate y[b] + else if (a_index[a] < b_index[b]) { + if (!functor(a_values + a * el_len, + zero.data(), + c_values + nnz * el_len, + el_len)) { + c_index[nnz] = a_index[a]; + ++nnz; + } + ++a; + } + // coordinate x[a] > coordinate y[b] + else if (a_index[a] > b_index[b]) { + if (!functor(zero.data(), + b_values[b_index[b]], + c_values + nnz * el_len, + el_len)) { + c_index[nnz] = b_index[b]; + ++nnz; + } + ++b; + } + } + // a tail + while (a < len_a) { + if (!functor(a_values + a * el_len, + zero.data(), + c_values + nnz * el_len, + el_len)) { + c_index[nnz] = a_index[a]; + ++nnz; + } + ++a; + } + // b tail + while (b < (is_divide ? len_b_max : len_b)) { + if (!functor(zero.data(), + b_values[b_index[b]], + c_values + nnz * el_len, + el_len)) { + c_index[nnz] = b_index[b]; + ++nnz; + } + ++b; + } +} + +// SparseCooTensor elementwise op, only support same shape tensor now +template +void ElementWiseCooKernelImpl(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + SparseCooTensor* out, + const Functor& functor) { + PADDLE_ENFORCE_EQ(x.dims(), + y.dims(), + phi::errors::InvalidArgument( + "Currently only support same shape elementwise " + "compute. The input tensor X's shape " + "should be identical with Y's shape. 
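Merge above is a two-pointer sweep over two sorted lists of flattened indices: when both sides hold an index the functor is applied to the two value blocks, when only one side holds it the other operand is an implicit zero block, and result blocks that are entirely zero are dropped (the return value of BinaryOPWithZeroCompareFunctor). A simplified Python sketch of the same control flow, with scalar values and without the divide-specific expansion of b (illustrative only):

import operator

def merge_sparse(a_idx, a_val, b_idx, b_val, op):
    # Two-pointer merge of two sorted (index, value) sparse vectors.
    ia = ib = 0
    out_idx, out_val = [], []

    def emit(i, v):
        if v != 0:              # mirror the is_zero check in the functor
            out_idx.append(i)
            out_val.append(v)

    while ia < len(a_idx) and ib < len(b_idx):
        if a_idx[ia] == b_idx[ib]:
            emit(a_idx[ia], op(a_val[ia], b_val[ib]))
            ia += 1
            ib += 1
        elif a_idx[ia] < b_idx[ib]:
            emit(a_idx[ia], op(a_val[ia], 0))
            ia += 1
        else:
            emit(b_idx[ib], op(0, b_val[ib]))
            ib += 1
    while ia < len(a_idx):      # a tail
        emit(a_idx[ia], op(a_val[ia], 0))
        ia += 1
    while ib < len(b_idx):      # b tail
        emit(b_idx[ib], op(0, b_val[ib]))
        ib += 1
    return out_idx, out_val

print(merge_sparse([1, 3, 6], [2.0, 4.0, 1.0],
                   [0, 3, 7], [5.0, -4.0, 3.0], operator.add))
# ([0, 1, 6, 7], [5.0, 2.0, 1.0, 3.0]) -- index 3 cancels to zero and is dropped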
But received X's " + "shape = [%s], Y's shape = [%s].", + x.dims(), + y.dims())); + int64_t element_size = 1; + for (auto j = 1; j < x.non_zero_elements().dims().size(); ++j) { + element_size *= x.non_zero_elements().dims()[j]; + } + IntT nnz = 0; + const auto x_values = x.non_zero_elements().data(); + const auto y_values = y.non_zero_elements().data(); + const auto sparse_dim = x.non_zero_indices().dims()[0]; + const bool is_divide = std::is_same>::value; + + int64_t max_len = 1; + for (auto j = 0; j < sparse_dim; ++j) { + max_len *= x.dims()[j]; + } + + std::vector sparse_offsets(sparse_dim), x_indexs(x.nnz()), + y_indexs(y.nnz()); + + phi::funcs::sparse::CalcOffsetsPerDim( + x.dims(), sparse_dim, sparse_offsets.data()); + + phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data(), + sparse_offsets.data(), + x.nnz(), + sparse_dim, + 0, + 1, + x_indexs.data()); + + phi::funcs::sparse::FlattenIndices(y.non_zero_indices().data(), + sparse_offsets.data(), + y.nnz(), + sparse_dim, + 0, + 1, + y_indexs.data()); + + std::vector out_indexs; + std::vector out_values_vec; + if (is_divide) { + out_indexs.reserve(max_len); + } else { + out_indexs.reserve(x.nnz() + y.nnz()); + } + out_values_vec.reserve(max_len * element_size); + + // merge x and y + Merge(element_size, + x_indexs.data(), + x_values, + x_indexs.size(), + y_indexs.data(), + y_values, + y_indexs.size(), + max_len, + out_indexs.data(), + out_values_vec.data(), + nnz, + functor, + is_divide); + + std::vector out_indices_vec; + out_indices_vec.resize(nnz * sparse_dim); + + Dim const_dims; + for (auto i = 0; i < x.dims().size(); i++) { + const_dims[i] = x.dims()[i]; + } + + funcs::sparse::IndexToCoordinate(out_indexs.data(), + const_dims, + nnz, + sparse_dim, + 0, + 1, + out_indices_vec.data()); + + if (nnz == 0) { + phi::DenseTensor out_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + phi::DenseTensor out_values = + phi::EmptyLike(dev_ctx, x.non_zero_elements()); + out->SetMember(out_indices, out_values, x.dims()); + } else { + DenseTensorMeta indices_meta( + paddle::experimental::CppTypeToDataType::Type(), + phi::make_ddim( + {static_cast(sparse_dim), static_cast(nnz)}), + DataLayout::NCHW); + auto indeces_dim = vectorize(slice_ddim( + x.non_zero_elements().dims(), 1, x.non_zero_elements().dims().size())); + indeces_dim.insert(indeces_dim.begin(), nnz); + DenseTensorMeta values_meta( + paddle::experimental::CppTypeToDataType::Type(), + phi::make_ddim(indeces_dim), + DataLayout::NCHW); + phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); + phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); + + std::memcpy(out_indices.data(), + out_indices_vec.data(), + sizeof(IntT) * sparse_dim * nnz); + std::memcpy(out_values.data(), + out_values_vec.data(), + sizeof(T) * nnz * element_size); + + out->SetMember(out_indices, out_values, x.dims()); + } +} + +#define DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(name) \ + template \ + void ElementWise##name##CsrCPUKernel(const Context& dev_ctx, \ + const SparseCsrTensor& x, \ + const SparseCsrTensor& y, \ + SparseCsrTensor* out) { \ + funcs::name##Functor functor; \ + auto coo_x = SparseCsrToCoo(dev_ctx, x); \ + auto coo_y = SparseCsrToCoo(dev_ctx, y); \ + DenseTensor indeces; \ + DenseTensor values; \ + SparseCooTensor coo_out; \ + coo_out.SetMember(indeces, values, x.dims()); \ + ElementWiseCooKernelImpl>( \ + dev_ctx, coo_x, coo_y, &coo_out, functor); \ + *out = SparseCooToCsr(dev_ctx, coo_out); \ + } + +#define DEFINE_CSR_ELEMENTWISE_KERNEL(name) \ 
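Before merging, the kernel flattens every sparse_dim-dimensional coordinate into a single row-major linear index (CalcOffsetsPerDim + FlattenIndices) so that coordinates can be compared with one integer comparison, and IndexToCoordinate converts the merged indices back at the end. A small Python equivalent of that round trip (illustrative only, assuming row-major strides as above):

def calc_offsets(dims):
    # offsets[d] = product of dims[d+1:], i.e. row-major strides
    offsets = [1] * len(dims)
    for d in range(len(dims) - 2, -1, -1):
        offsets[d] = offsets[d + 1] * dims[d + 1]
    return offsets

def flatten_index(coord, offsets):
    return sum(c * o for c, o in zip(coord, offsets))

def index_to_coordinate(flat, dims):
    coord = []
    for o in calc_offsets(dims):
        coord.append(flat // o)
        flat %= o
    return coord

dims = [2, 4, 4]
offsets = calc_offsets(dims)
print(offsets)                              # [16, 4, 1]
print(flatten_index([1, 2, 3], offsets))    # 27
print(index_to_coordinate(27, dims))        # [1, 2, 3]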
+ template \ + void ElementWise##name##CsrKernel(const Context& dev_ctx, \ + const SparseCsrTensor& x, \ + const SparseCsrTensor& y, \ + SparseCsrTensor* out) { \ + PD_VISIT_INTEGRAL_TYPES( \ + x.non_zero_crows().dtype(), "ElementWise##name##CsrCPUKernel", ([&] { \ + ElementWise##name##CsrCPUKernel(dev_ctx, x, y, out); \ + })); \ + } + +#define DEFINE_COO_ELEMENTWISE_CPU_KERNEL(name) \ + template \ + void ElementWise##name##CooCPUKernel(const Context& dev_ctx, \ + const SparseCooTensor& x, \ + const SparseCooTensor& y, \ + SparseCooTensor* out) { \ + funcs::name##Functor functor; \ + ElementWiseCooKernelImpl>( \ + dev_ctx, x, y, out, functor); \ + } + +#define DEFINE_COO_ELEMENTWISE_KERNEL(name) \ + template \ + void ElementWise##name##CooKernel(const Context& dev_ctx, \ + const SparseCooTensor& x, \ + const SparseCooTensor& y, \ + SparseCooTensor* out) { \ + PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \ + "ElementWise##name##CooCPUKernel", \ + ([&] { \ + ElementWise##name##CooCPUKernel( \ + dev_ctx, x, y, out); \ + })); \ + } + +DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Add) +DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Subtract) +DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Multiply) +DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(Divide) + +DEFINE_CSR_ELEMENTWISE_KERNEL(Add) +DEFINE_CSR_ELEMENTWISE_KERNEL(Subtract) +DEFINE_CSR_ELEMENTWISE_KERNEL(Multiply) +DEFINE_CSR_ELEMENTWISE_KERNEL(Divide) + +DEFINE_COO_ELEMENTWISE_CPU_KERNEL(Add) +DEFINE_COO_ELEMENTWISE_CPU_KERNEL(Subtract) +DEFINE_COO_ELEMENTWISE_CPU_KERNEL(Multiply) +DEFINE_COO_ELEMENTWISE_CPU_KERNEL(Divide) + +DEFINE_COO_ELEMENTWISE_KERNEL(Add) +DEFINE_COO_ELEMENTWISE_KERNEL(Subtract) +DEFINE_COO_ELEMENTWISE_KERNEL(Multiply) +DEFINE_COO_ELEMENTWISE_KERNEL(Divide) + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(add_csr_csr, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddCsrKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(add_coo_coo, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddCooKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(subtract_csr_csr, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseSubtractCsrKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(subtract_coo_coo, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseSubtractCooKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(multiply_csr_csr, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseMultiplyCsrKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(multiply_coo_coo, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseMultiplyCooKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(divide_csr_csr, + CPU, + ALL_LAYOUT, + 
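Note the design choice in the CSR path above: rather than merging compressed rows directly, each CSR operand is converted to COO (SparseCsrToCoo), the COO implementation performs the element-wise merge, and the result is converted back (SparseCooToCsr). The same route expressed with SciPy, purely as an analogy for the format round trip (not Paddle code):

import numpy as np
from scipy import sparse

x = sparse.csr_matrix(np.array([[0.0, 1.0], [2.0, 0.0]]))
y = sparse.csr_matrix(np.array([[3.0, 0.0], [0.0, 4.0]]))

# CSR -> COO, element-wise combine, then back to CSR.
out = (x.tocoo() + y.tocoo()).tocsr()
print(out.toarray())    # [[3. 1.] [2. 4.]]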
phi::sparse::ElementWiseDivideCsrKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(divide_coo_coo, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseDivideCooKernel, + float, + double, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/elementwise_grad_kernel.h b/paddle/phi/kernels/sparse/elementwise_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..df3feb597e3c61dee51173b15f469461b48781c9 --- /dev/null +++ b/paddle/phi/kernels/sparse/elementwise_grad_kernel.h @@ -0,0 +1,112 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { +namespace sparse { + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD(name) \ + DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD_WITH_TYPE(name, Csr) \ + \ + DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD_WITH_TYPE(name, Coo) + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC(name) \ + DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC_WITH_TYPE(name, Csr) \ + \ + DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC_WITH_TYPE(name, Coo) + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD_WITH_TYPE(name, type) \ + template \ + void ElementWise##name##type##GradKernel(const Context& dev_ctx, \ + const Sparse##type##Tensor& x, \ + const Sparse##type##Tensor& y, \ + const Sparse##type##Tensor& dout, \ + Sparse##type##Tensor* dx, \ + Sparse##type##Tensor* dy); + +#define DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC_WITH_TYPE(name, type) \ + template \ + std::vector ElementWise##name##type##Grad( \ + const Context& dev_ctx, \ + const Sparse##type##Tensor& x, \ + const Sparse##type##Tensor& y, \ + const Sparse##type##Tensor& dout) { \ + Sparse##type##Tensor dx; \ + Sparse##type##Tensor dy; \ + ElementWise##name##type##GradKernel( \ + dev_ctx, x, y, dout, &dx, &dy); \ + return std::vector{dx, dy}; \ + } + +DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD(Add) +DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD(Subtract) +DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD(Multiply) + +DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC(Add) +DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC(Subtract) +DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC(Multiply) + +template +void ElementWiseDivideCsrGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& out, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy); + +template +void ElementWiseDivideCooGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& out, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy); + +template +std::vector ElementWiseDivideCsrGrad( + const 
Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& out, + const SparseCsrTensor& dout) { + SparseCsrTensor dx; + SparseCsrTensor dy; + ElementWiseDivideCsrGradKernel( + dev_ctx, x, y, out, dout, &dx, &dy); + return std::vector{dx, dy}; +} + +template +std::vector ElementWiseDivideCooGrad( + const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& out, + const SparseCooTensor& dout) { + SparseCooTensor dx; + SparseCooTensor dy; + ElementWiseDivideCooGradKernel( + dev_ctx, x, y, out, dout, &dx, &dy); + return std::vector{dx, dy}; +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/elementwise_kernel.h b/paddle/phi/kernels/sparse/elementwise_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..38a0cc44701c29d7373e748f7f666d4c6e4f4660 --- /dev/null +++ b/paddle/phi/kernels/sparse/elementwise_kernel.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { +namespace sparse { + +#define DEFINE_ELEMENTWISE_KERNEL_HEAD(name) \ + DEFINE_ELEMENTWISE_KERNEL_HEAD_WITH_TYPE(name, Csr) \ + \ + DEFINE_ELEMENTWISE_KERNEL_HEAD_WITH_TYPE(name, Coo) + +#define DEFINE_ELEMENTWISE_KERNEL_FUNC(name) \ + DEFINE_CSR_ELEMENTWISE_KERNEL_FUNC(name) \ + \ + DEFINE_COO_ELEMENTWISE_KERNEL_FUNC(name) + +#define DEFINE_ELEMENTWISE_KERNEL_HEAD_WITH_TYPE(name, type) \ + template \ + void ElementWise##name##type##Kernel(const Context& dev_ctx, \ + const Sparse##type##Tensor& x, \ + const Sparse##type##Tensor& y, \ + Sparse##type##Tensor* out); + +#define DEFINE_CSR_ELEMENTWISE_KERNEL_FUNC(name) \ + template \ + SparseCsrTensor ElementWise##name##Csr(const Context& dev_ctx, \ + const SparseCsrTensor& x, \ + const SparseCsrTensor& y) { \ + DenseTensor non_zero_crows; \ + DenseTensor non_zero_cols; \ + DenseTensor non_zero_elements; \ + SparseCsrTensor out( \ + non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); \ + ElementWise##name##CsrKernel(dev_ctx, x, y, &out); \ + return out; \ + } + +#define DEFINE_COO_ELEMENTWISE_KERNEL_FUNC(name) \ + template \ + SparseCooTensor ElementWise##name##Coo(const Context& dev_ctx, \ + const SparseCooTensor& x, \ + const SparseCooTensor& y) { \ + DenseTensor non_zero_indices; \ + DenseTensor non_zero_elements; \ + SparseCooTensor out(non_zero_indices, non_zero_elements, x.dims()); \ + ElementWise##name##CooKernel(dev_ctx, x, y, &out); \ + return out; \ + } + +DEFINE_ELEMENTWISE_KERNEL_HEAD(Add) +DEFINE_ELEMENTWISE_KERNEL_HEAD(Subtract) +DEFINE_ELEMENTWISE_KERNEL_HEAD(Multiply) +DEFINE_ELEMENTWISE_KERNEL_HEAD(Divide) + +DEFINE_ELEMENTWISE_KERNEL_FUNC(Add) +DEFINE_ELEMENTWISE_KERNEL_FUNC(Subtract) +DEFINE_ELEMENTWISE_KERNEL_FUNC(Multiply) +DEFINE_ELEMENTWISE_KERNEL_FUNC(Divide) + +} // namespace sparse +} 
// namespace phi diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index b7d53b31bc3baeeb658f9f298b8a46fe9238694b..ca466780da45037a1515b60aefc543c8b51a306f 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -70,6 +70,10 @@ cc_test( test_sparse_activation_dev_api SRCS test_sparse_activation_dev_api.cc DEPS phi phi_api_utils) +cc_test( + test_sparse_elementwise_dev_api + SRCS test_sparse_elementwise_dev_api.cc + DEPS phi phi_api_utils) cc_test( test_math_function diff --git a/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..50848ae5f1ce7a7688f8b65fcd9f271caa934d48 --- /dev/null +++ b/paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc @@ -0,0 +1,422 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/elementwise_add_grad_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_grad_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/elementwise_multiply_grad_kernel.h" +#include "paddle/phi/kernels/elementwise_multiply_kernel.h" +#include "paddle/phi/kernels/elementwise_subtract_grad_kernel.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h" +#include "paddle/phi/kernels/sparse/elementwise_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +namespace phi { +namespace tests { + +#define TEST_ELEMENTWISE_OP(name) \ + TEST_ELEMENTWISE_OP_WITH_TYPE(name, Csr) \ + \ + TEST_ELEMENTWISE_OP_WITH_TYPE(name, Coo) + +#define TEST_ELEMENTWISE_OP_WITH_TYPE(name, type) \ + template \ + void TestElementWise##name##type(const Context& dev_ctx_cpu, \ + const Sparse##type##Tensor& x, \ + const Sparse##type##Tensor& y, \ + const DDim& dense_dims) { \ + auto out = sparse::ElementWise##name##type(dev_ctx_cpu, x, y); \ + const DenseTensor denseX = \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, x); \ + const DenseTensor denseY = \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, y); \ + const DenseTensor denseOut = \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, out); \ + auto expectResult = name(dev_ctx_cpu, denseX, denseY); \ + for (int j = 0; j < denseOut.numel(); ++j) { \ + auto actualResultRow = denseOut.template data()[j]; \ + auto expectResultRow = expectResult.template data()[j]; \ + if (std::is_same::value || std::is_same::value) { \ + if (!std::isnan(expectResultRow)) { \ + ASSERT_DOUBLE_EQ(expectResultRow, 
actualResultRow); \ + } \ + } else { \ + ASSERT_EQ(expectResultRow, actualResultRow); \ + } \ + } \ + } + +TEST_ELEMENTWISE_OP(Add) +TEST_ELEMENTWISE_OP(Subtract) +TEST_ELEMENTWISE_OP(Multiply) +TEST_ELEMENTWISE_OP(Divide) + +TEST(DEV_API, sparse_elementwise_coo_kernel_double) { + using T = double; + using IntT = int64_t; + for (int epoch = 0; epoch < 100; ++epoch) { + DDim dense_dims = phi::make_ddim({2, 4, 4}); + IntT sparse_dim = 2; + // 32els + std::vector x_dense_data = {0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0, + 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0, + 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0}; + + std::vector y_dense_data = {0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0, + 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0}; + + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + + phi::DenseTensor dense_x( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(paddle::platform::CPUPlace()); + + memcpy(dense_x_data, x_dense_data.data(), x_dense_data.size() * sizeof(T)); + + phi::DenseTensor dense_y( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(paddle::platform::CPUPlace()); + + memcpy(dense_y_data, y_dense_data.data(), y_dense_data.size() * sizeof(T)); + + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.Init(); + + auto coo_x = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); + auto coo_y = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_y, sparse_dim); + + TestElementWiseAddCoo(dev_ctx_cpu, coo_x, coo_y, dense_dims); + TestElementWiseSubtractCoo(dev_ctx_cpu, coo_x, coo_y, dense_dims); + TestElementWiseMultiplyCoo(dev_ctx_cpu, coo_x, coo_y, dense_dims); + TestElementWiseDivideCoo(dev_ctx_cpu, coo_x, coo_y, dense_dims); + } +} + +TEST(DEV_API, sparse_elementwise_csr_kernel_float) { + using T = float; + + DDim dense_dims = phi::make_ddim({6, 4}); + // 24els + std::vector x_dense_data = {0.0, 0.0, 4.0, 2.0, 6.0, 3.0, 0.2, 0.1, + 2.2, 1.1, 4.2, 2.1, 0.4, 0.2, 0.0, 0.0, + 4.4, 2.2, 0.6, 0.3, 2.6, 1.3, 0.0, 0.0}; + std::vector y_dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 3.5, + 0.7, 0.0, 3.5, 0.7, 3.2, 0.1, 0.0, 3.2, + 1.0, 0.0, 1.2, 0.5, 0.7, 3.3, 0.0, 9.0}; + + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + + phi::DenseTensor dense_x( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(paddle::platform::CPUPlace()); + + memcpy(dense_x_data, x_dense_data.data(), x_dense_data.size() * sizeof(T)); + + phi::DenseTensor dense_y( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(paddle::platform::CPUPlace()); + + memcpy(dense_y_data, y_dense_data.data(), y_dense_data.size() * sizeof(T)); + + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.Init(); + + auto csr_x = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); + auto csr_y = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_y); + + TestElementWiseAddCsr(dev_ctx_cpu, csr_x, csr_y, dense_dims); + 
TestElementWiseSubtractCsr(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseMultiplyCsr(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseDivideCsr(dev_ctx_cpu, csr_x, csr_y, dense_dims); +} + +#define TEST_ELEMENTWISE_OP_GRAD(name) \ + TEST_ELEMENTWISE_OP_GRAD_WITH_TYPE(name, Csr) \ + \ + TEST_ELEMENTWISE_OP_GRAD_WITH_TYPE(name, Coo) + +#define TEST_ELEMENTWISE_OP_GRAD_WITH_TYPE(name, type) \ + template \ + void TestElementWise##name##type##Grad(const Context& dev_ctx_cpu, \ + const Sparse##type##Tensor& x, \ + const Sparse##type##Tensor& y, \ + const DDim& dense_dims) { \ + auto out = sparse::ElementWise##name##type(dev_ctx_cpu, x, y); \ + auto dresult = \ + sparse::ElementWise##name##type##Grad(dev_ctx_cpu, x, y, out); \ + \ + DenseTensor expectdy = phi::Empty( \ + dev_ctx_cpu, \ + DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); \ + DenseTensor expectdx = phi::Empty( \ + dev_ctx_cpu, \ + DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); \ + \ + phi::name##GradKernel( \ + dev_ctx_cpu, \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, x), \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, y), \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, out), \ + -1, \ + &expectdx, \ + &expectdy); \ + const DenseTensor densedX = \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, dresult[0]); \ + const DenseTensor densedY = \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, dresult[1]); \ + const DenseTensor denseOut = \ + sparse::Sparse##type##ToDense(dev_ctx_cpu, out); \ + \ + for (int j = 0; j < densedX.numel(); ++j) { \ + auto actualResultRow = densedX.template data()[j]; \ + auto expectResultRow = expectdx.template data()[j]; \ + if (std::is_same::value || std::is_same::value) { \ + if (!std::isnan(expectResultRow)) { \ + ASSERT_DOUBLE_EQ(expectResultRow, actualResultRow); \ + } \ + } else { \ + ASSERT_EQ(expectResultRow, actualResultRow); \ + } \ + } \ + for (int j = 0; j < densedY.numel(); ++j) { \ + auto actualResultRow = densedY.template data()[j]; \ + auto expectResultRow = expectdy.template data()[j]; \ + if (std::is_same::value || std::is_same::value) { \ + if (!std::isnan(expectResultRow)) { \ + ASSERT_DOUBLE_EQ(expectResultRow, actualResultRow); \ + } \ + } else { \ + ASSERT_EQ(expectResultRow, actualResultRow); \ + } \ + } \ + } + +TEST_ELEMENTWISE_OP_GRAD(Add) +TEST_ELEMENTWISE_OP_GRAD(Subtract) +TEST_ELEMENTWISE_OP_GRAD(Multiply) + +template +void TestElementWiseDivideCsrGrad(const Context& dev_ctx_cpu, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const DDim& dense_dims) { + auto out = sparse::ElementWiseDivideCsr(dev_ctx_cpu, x, y); + auto dresult = + sparse::ElementWiseDivideCsrGrad(dev_ctx_cpu, x, y, out, out); + DenseTensor expectdy = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + DenseTensor expectdx = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + phi::DivideGradKernel(dev_ctx_cpu, + sparse::SparseCsrToDense(dev_ctx_cpu, x), + sparse::SparseCsrToDense(dev_ctx_cpu, y), + sparse::SparseCsrToDense(dev_ctx_cpu, out), + sparse::SparseCsrToDense(dev_ctx_cpu, out), + -1, + &expectdx, + &expectdy); + const DenseTensor densedX = + sparse::SparseCsrToDense(dev_ctx_cpu, dresult[0]); + const DenseTensor densedY = + sparse::SparseCsrToDense(dev_ctx_cpu, dresult[1]); + const DenseTensor denseOut = sparse::SparseCsrToDense(dev_ctx_cpu, out); + for (int j = 0; j < densedX.numel(); ++j) { + auto actualResultRow = densedX.template data()[j]; + auto 
expectResultRow = expectdx.template data()[j]; + if (!std::isnan(expectResultRow)) { + ASSERT_DOUBLE_EQ(expectResultRow, actualResultRow); + } + } + for (int j = 0; j < densedY.numel(); ++j) { + auto actualResultRow = densedY.template data()[j]; + auto expectResultRow = expectdy.template data()[j]; + if (!std::isnan(expectResultRow)) { + ASSERT_DOUBLE_EQ(expectResultRow, actualResultRow); + } + } +} + +template +void TestElementWiseDivideCooGrad(const Context& dev_ctx_cpu, + const SparseCooTensor& x, + const SparseCooTensor& y, + const DDim& dense_dims) { + auto out = sparse::ElementWiseDivideCoo(dev_ctx_cpu, x, y); + auto dresult = + sparse::ElementWiseDivideCooGrad(dev_ctx_cpu, x, y, out, out); + DenseTensor expectdy = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + DenseTensor expectdx = phi::Empty( + dev_ctx_cpu, + DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + phi::DivideGradKernel(dev_ctx_cpu, + sparse::SparseCooToDense(dev_ctx_cpu, x), + sparse::SparseCooToDense(dev_ctx_cpu, y), + sparse::SparseCooToDense(dev_ctx_cpu, out), + sparse::SparseCooToDense(dev_ctx_cpu, out), + -1, + &expectdx, + &expectdy); + const DenseTensor densedX = + sparse::SparseCooToDense(dev_ctx_cpu, dresult[0]); + const DenseTensor densedY = + sparse::SparseCooToDense(dev_ctx_cpu, dresult[1]); + const DenseTensor denseOut = sparse::SparseCooToDense(dev_ctx_cpu, out); + for (int j = 0; j < densedX.numel(); ++j) { + auto actualResultRow = densedX.template data()[j]; + auto expectResultRow = expectdx.template data()[j]; + if (!std::isnan(expectResultRow)) { + ASSERT_DOUBLE_EQ(expectResultRow, actualResultRow); + } + } + for (int j = 0; j < densedY.numel(); ++j) { + auto actualResultRow = densedY.template data()[j]; + auto expectResultRow = expectdy.template data()[j]; + if (!std::isnan(expectResultRow)) { + ASSERT_DOUBLE_EQ(expectResultRow, actualResultRow); + } + } +} + +TEST(DEV_API, sparse_elementwise_csr_grad_kernel_float) { + using T = float; + DDim dense_dims = phi::make_ddim({2, 3, 4}); + + std::vector x_dense_data = {0.0, 0.0, 4.0, 2.0, 6.0, 3.0, 0.2, 0.1, + 2.2, 1.1, 4.2, 2.1, 0.4, 0.2, 0.0, 0.0, + 4.4, 2.2, 0.6, 0.3, 2.6, 1.3, 0.0, 0.0}; + + std::vector y_dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 3.5, + 0.7, 0.0, 3.5, 0.7, 3.2, 0.1, 0.0, 3.2, + 1.0, 0.0, 1.2, 0.5, 0.7, 3.3, 0.0, 9.0}; + + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + + phi::DenseTensor dense_x( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(paddle::platform::CPUPlace()); + memcpy(dense_x_data, x_dense_data.data(), x_dense_data.size() * sizeof(T)); + + phi::DenseTensor dense_y( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(paddle::platform::CPUPlace()); + memcpy(dense_y_data, y_dense_data.data(), y_dense_data.size() * sizeof(T)); + + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.Init(); + + auto csr_x = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); + auto csr_y = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_y); + + auto dx = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_y); + auto dy = 
sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); + + TestElementWiseAddCsrGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseSubtractCsrGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseMultiplyCsrGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseDivideCsrGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); +} + +TEST(DEV_API, sparse_elementwise_coo_grad_kernel_double) { + using T = double; + int64_t sparse_dim = 2; + DDim dense_dims = phi::make_ddim({3, 4}); + std::vector x_dense_data = { + 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0, 3.2, 0.0, 0.0}; + std::vector y_dense_data = { + 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 3.5, 0.7, 0.0, 3.5, 0.7}; + + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + + phi::DenseTensor dense_x( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(paddle::platform::CPUPlace()); + memcpy(dense_x_data, x_dense_data.data(), x_dense_data.size() * sizeof(T)); + + phi::DenseTensor dense_y( + alloc.get(), + phi::DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW)); + auto* dense_y_data = dense_y.mutable_data(paddle::platform::CPUPlace()); + memcpy(dense_y_data, y_dense_data.data(), y_dense_data.size() * sizeof(T)); + + phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx_cpu.Init(); + + auto csr_x = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); + auto csr_y = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_y, sparse_dim); + + auto dx = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_y, sparse_dim); + auto dy = sparse::DenseToSparseCoo(dev_ctx_cpu, dense_x, sparse_dim); + + TestElementWiseAddCooGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseSubtractCooGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseMultiplyCooGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); + TestElementWiseDivideCooGrad(dev_ctx_cpu, csr_x, csr_y, dense_dims); +} + +} // namespace tests +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py new file mode 100644 index 0000000000000000000000000000000000000000..61932cf4a7b0a8ad1c586cedec476b2e07445626 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import unittest +from operator import __add__, __sub__, __mul__, __truediv__ + +import numpy as np +import paddle +from paddle.fluid.framework import _test_eager_guard + +op_list = [__add__, __sub__, __mul__, __truediv__] + + +def get_actual_res(x, y, op): + if op == __add__: + res = paddle.incubate.sparse.add(x, y) + elif op == __sub__: + res = paddle.incubate.sparse.subtract(x, y) + elif op == __mul__: + res = paddle.incubate.sparse.multiply(x, y) + elif op == __truediv__: + res = paddle.incubate.sparse.divide(x, y) + else: + raise ValueError("unsupported op") + return res + + +class TestSparseElementWiseAPI(unittest.TestCase): + """ + test paddle.sparse.add, subtract, multiply, divide + """ + + def setUp(self): + paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + np.random.seed(2022) + self.op_list = op_list + self.csr_shape = [128, 256] + self.coo_shape = [4, 8, 3, 5] + self.support_dtypes = ['float32', 'float64', 'int32', 'int64'] + + def func_test_csr(self, op): + for dtype in self.support_dtypes: + x = np.random.randint(-255, 255, size=self.csr_shape).astype(dtype) + y = np.random.randint(-255, 255, size=self.csr_shape).astype(dtype) + + dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False) + dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False) + + s_dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False) + s_dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False) + csr_x = s_dense_x.to_sparse_csr() + csr_y = s_dense_y.to_sparse_csr() + + actual_res = get_actual_res(csr_x, csr_y, op) + actual_res.backward(actual_res) + + expect_res = op(dense_x, dense_y) + expect_res.backward(expect_res) + + self.assertTrue( + np.allclose(expect_res.numpy(), + actual_res.to_dense().numpy(), + equal_nan=True)) + if not (op == __truediv__ and dtype in ['int32', 'int64']): + self.assertTrue( + np.allclose(dense_x.grad.numpy(), + csr_x.grad.to_dense().numpy(), + equal_nan=True)) + self.assertTrue( + np.allclose(dense_y.grad.numpy(), + csr_y.grad.to_dense().numpy(), + equal_nan=True)) + + def func_test_coo(self, op): + for sparse_dim in range(len(self.coo_shape) - 1, len(self.coo_shape)): + for dtype in self.support_dtypes: + x = np.random.randint(-255, 255, + size=self.coo_shape).astype(dtype) + y = np.random.randint(-255, 255, + size=self.coo_shape).astype(dtype) + + dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False) + dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False) + + s_dense_x = paddle.to_tensor(x, + dtype=dtype, + stop_gradient=False) + s_dense_y = paddle.to_tensor(y, + dtype=dtype, + stop_gradient=False) + coo_x = s_dense_x.to_sparse_coo(sparse_dim) + coo_y = s_dense_y.to_sparse_coo(sparse_dim) + + actual_res = get_actual_res(coo_x, coo_y, op) + actual_res.backward(actual_res) + + expect_res = op(dense_x, dense_y) + expect_res.backward(expect_res) + + self.assertTrue( + np.allclose(expect_res.numpy(), + actual_res.to_dense().numpy(), + equal_nan=True)) + self.assertTrue( + np.allclose(dense_x.grad.numpy(), + coo_x.grad.to_dense().numpy(), + equal_nan=True)) + self.assertTrue( + np.allclose(dense_y.grad.numpy(), + coo_y.grad.to_dense().numpy(), + equal_nan=True)) + + def test_support_dtypes_csr(self): + paddle.device.set_device('cpu') + if paddle.device.get_device() == "cpu": + with _test_eager_guard(): + for op in op_list: + self.func_test_csr(op) + + def test_support_dtypes_coo(self): + paddle.device.set_device('cpu') + if paddle.device.get_device() == "cpu": + with 
_test_eager_guard(): + for op in op_list: + self.func_test_coo(op) + + +if __name__ == "__main__": + paddle.device.set_device('cpu') + unittest.main() diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index 5fe86995e1d307ea2c32fffb3d268cf64f44c156..05dd8b6d5657d7ee5fc2dc93e9ec5e0fc6a03c6b 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ b/python/paddle/incubate/sparse/__init__.py @@ -22,6 +22,11 @@ from .unary import tanh from .binary import matmul from .binary import masked_matmul +from .math import add +from .math import divide +from .math import multiply +from .math import subtract + from . import nn __all__ = [ @@ -32,4 +37,8 @@ __all__ = [ 'tanh', 'matmul', 'masked_matmul', + 'add', + 'subtract', + 'multiply', + 'divide', ] diff --git a/python/paddle/incubate/sparse/math.py b/python/paddle/incubate/sparse/math.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a984c3ad5beab9bc40ad3b375d66f75d023df6 --- /dev/null +++ b/python/paddle/incubate/sparse/math.py @@ -0,0 +1,260 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +sparse math functions +""" +from __future__ import print_function + +from paddle import _C_ops, in_dynamic_mode, device, int32, int64 +from paddle.tensor import cast +from paddle.incubate.sparse import sparse_csr_tensor + + +def _cast_coo(x, dtype, name=None): + indices = x.indices() + values = cast(x.values(), dtype) + return _C_ops.final_state_sparse_create_sparse_coo_tensor( + values, indices, x.shape) + + +def _cast_csr(x, dtype, name=None): + crows = x.crows() + cols = x.cols() + values = cast(x.values(), dtype) + return sparse_csr_tensor(crows, cols, values, x.shape) + + +def _cast(x, dtype, name=None): + if x.is_sparse_coo(): + return _cast_coo(x, dtype, name) + return _cast_csr(x, dtype, name) + + +def add(x, y, name=None): + """ + Add two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse + type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical. + The equation is: + + .. math:: + out = x + y + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: the result tensor. + + Examples: + + .. 
code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + paddle.device.set_device("cpu") + + with _test_eager_guard(): + x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32') + y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32') + sparse_x = x.to_sparse_csr() + sparse_y = y.to_sparse_csr() + sparse_z = paddle.incubate.sparse.add(sparse_x, sparse_y) + print(sparse_z.to_dense()) + + # [[ 0., -1., 0., 0.], + # [ 0., 2., -6., 0.], + # [ 6., 8., 4., 8.]] + + """ + assert device.get_device( + ) == "cpu", "Currently, Sparse add only support CPU device." + assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode" + assert x.is_sparse_csr() == y.is_sparse_csr( + ), f"Expect sparse tensor type to be same" + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_add(x, y) + else: + raise ValueError( + "Currently, sparse.add only support the input of SparseCooTensor or SparseCsrTensor" + ) + + +def subtract(x, y, name=None): + """ + Subtract two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse + type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical. + The equation is: + + .. math:: + out = x - y + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: the result tensor. + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + paddle.device.set_device("cpu") + + with _test_eager_guard(): + x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32') + y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32') + sparse_x = x.to_sparse_csr() + sparse_y = y.to_sparse_csr() + sparse_z = paddle.incubate.sparse.subtract(sparse_x, sparse_y) + print(sparse_z.to_dense()) + + # [[ 0., -1., 0., 4.], + # [ 0., -2., 0., 0.], + # [ 2., 2., -4., -8.]] + + """ + assert device.get_device( + ) == "cpu", "Currently, Sparse subtract only support CPU device." + assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode" + assert x.is_sparse_csr() == y.is_sparse_csr( + ), f"Expect sparse tensor type to be same" + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_subtract(x, y) + else: + raise ValueError( + "Currently, sparse.subtract only support the input of SparseCooTensor or SparseCsrTensor" + ) + + +def multiply(x, y, name=None): + """ + Multiply two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse + type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical. + The equation is: + + .. math:: + out = x * y + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: the result tensor. + + Examples: + + .. 
code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + paddle.device.set_device("cpu") + + with _test_eager_guard(): + x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32') + y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32') + sparse_x = x.to_sparse_csr() + sparse_y = y.to_sparse_csr() + sparse_z = paddle.incubate.sparse.multiply(sparse_x, sparse_y) + print(sparse_z.to_dense()) + + # [[ 0., 0., 0., -4.], + # [ 0., 0., 9., 0.], + # [ 8., 15., 0., 0.]] + + """ + assert device.get_device( + ) == "cpu", "Currently, Sparse multiply only supports CPU device." + assert in_dynamic_mode(), "Currently, Sparse API only supports dynamic mode" + assert x.is_sparse_csr() == y.is_sparse_csr( + ), "Expect x and y to have the same sparse tensor type" + if x.is_sparse_coo() or x.is_sparse_csr(): + return _C_ops.final_state_sparse_multiply(x, y) + else: + raise ValueError( + "Currently, sparse.multiply only supports inputs of type SparseCooTensor or SparseCsrTensor" + ) + + +def divide(x, y, name=None): + """ + Divide two sparse tensors element-wise. The shapes of x and y must be identical, and both inputs must have the same + sparse type (SparseCooTensor or SparseCsrTensor). If the inputs are SparseCooTensor, their sparse_dim must also be identical. + If the inputs have an integer data type (int32/int64), they are cast to float32 before dividing, so the result is a float32 sparse tensor. + The equation is: + + .. math:: + out = x / y + + Args: + x (Tensor): the input tensor; its data type should be float32, float64, int32 or int64. + y (Tensor): the input tensor; its data type should be float32, float64, int32 or int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: the result tensor. + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + paddle.device.set_device("cpu") + + with _test_eager_guard(): + x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32') + y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32') + sparse_x = x.to_sparse_csr() + sparse_y = y.to_sparse_csr() + sparse_z = paddle.incubate.sparse.divide(sparse_x, sparse_y) + print(sparse_z.to_dense()) + + # [[ nan , -inf. , nan , -1. ], + # [ nan , 0. , 1. , nan ], + # [ 2. , 1.66666663, 0. , 0. ]] + + """ + assert device.get_device( + ) == "cpu", "Currently, Sparse divide only supports CPU device."
+ assert in_dynamic_mode(), "Currently, Sparse API only supports dynamic mode" + assert x.is_sparse_csr() == y.is_sparse_csr( + ), "Expect x and y to have the same sparse tensor type" + + if x.is_sparse_coo() or x.is_sparse_csr(): + if x.dtype in [int32, int64]: + cx = _cast(x, 'float32') + cy = _cast(y, 'float32') + return _C_ops.final_state_sparse_divide(cx, cy) + return _C_ops.final_state_sparse_divide(x, y) + else: + raise ValueError( + "Currently, sparse.divide only supports inputs of type SparseCooTensor or SparseCsrTensor" + ) diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index 09bb610e5cb1c6cc5c77d003f07e173d5c9f2fd3..e3b61bae1509b774e97fbccee4d3a50292e85e4d 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -1,3 +1,12 @@ +- api : add + args : (Tensor x, Tensor y) + output : Tensor(out) + kernel : + func : add_coo_coo{sparse_coo -> sparse_coo}, + add_csr_csr{sparse_csr -> sparse_csr} + layout : x + backward : add_grad + - api : conv3d args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) output : Tensor(out), Tensor(rulebook) @@ -28,6 +37,24 @@ invoke : to_sparse_coo_impl(x, sparse_dim) backward : dense_to_coo_grad +- api : divide + args : (Tensor x, Tensor y) + output : Tensor(out) + kernel : + func : divide_coo_coo{sparse_coo -> sparse_coo}, + divide_csr_csr{sparse_csr -> sparse_csr} + layout : x + backward : divide_grad + +- api : multiply + args : (Tensor x, Tensor y) + output : Tensor(out) + kernel : + func : multiply_coo_coo{sparse_coo -> sparse_coo}, + multiply_csr_csr{sparse_csr -> sparse_csr} + layout : x + backward : multiply_grad + - api : relu args : (Tensor x) output : Tensor(out) @@ -63,6 +90,15 @@ layout : x backward : sqrt_grad +- api : subtract + args : (Tensor x, Tensor y) + output : Tensor(out) + kernel : + func : subtract_coo_coo{sparse_coo -> sparse_coo}, + subtract_csr_csr{sparse_csr -> sparse_csr} + layout : x + backward : subtract_grad + - api : tanh args : (Tensor x) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index a4e83411bc84613758f0505339d2855933d4a8a0..7ddaba0f0a424004242024c9ff028d58217ae1aa 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -1,3 +1,11 @@ +- backward_api : add_grad + forward : add(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + kernel : + func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, + add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} + - backward_api : conv3d_grad forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) @@ -25,6 +33,14 @@ output : Tensor(x_grad) invoke : to_dense_impl(out_grad) +- backward_api : divide_grad + forward : divide(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out, Tensor out_grad) + output : Tensor(x_grad),
Tensor(y_grad) + kernel : + func : divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, + divide_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} + - backward_api : masked_matmul_grad forward : masked_matmul(Tensor x, Tensor y, Tensor mask) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad) @@ -39,6 +55,14 @@ kernel : func : csr_dense_matmul_grad{sparse_csr, dense, dense -> sparse_csr, dense} +- backward_api : multiply_grad + forward : multiply(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + kernel : + func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, + multiply_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} + - backward_api : relu_grad forward : relu(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) @@ -74,6 +98,14 @@ kernel : func : sparse_coo_sqrt_grad {sparse_coo, sparse_coo -> sparse_coo} +- backward_api : subtract_grad + forward : subtract(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + kernel : + func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, + subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} + - backward_api : tanh_grad forward : tanh(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) diff --git a/python/setup.py.in b/python/setup.py.in index 17cb0ad6776aecdc76770092a4d9cc043606aca6..567a411d0980b8b0c3824fb192d4908f530b0651 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -281,6 +281,7 @@ packages=['paddle', 'paddle.incubate.tensor', 'paddle.incubate.multiprocessing', 'paddle.incubate.nn', + 'paddle.incubate.sparse', 'paddle.incubate.asp', 'paddle.incubate.passes', 'paddle.distribution',
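For reference, below is a minimal usage sketch of the new Python surface; it is not part of the patch. It assumes this change is applied and follows the constraints asserted in python/paddle/incubate/sparse/math.py (CPU device and dynamic/eager mode). Tensor.to_sparse_coo(sparse_dim) is assumed to be available on dense tensors, as in the accompanying unit tests; all other names are taken from this diff.

# Usage sketch only; a hypothetical driver script, not included in this change.
import paddle
from paddle.fluid.framework import _test_eager_guard
from paddle.incubate import sparse

paddle.device.set_device('cpu')  # the Python wrappers currently assert CPU

with _test_eager_guard():  # the wrappers currently assert dynamic (eager) mode
    x = paddle.to_tensor([[0., -1., 0., 2.], [0., 0., -3., 0.], [4., 5., 0., 0.]])
    y = paddle.to_tensor([[0., 0., 0., -2.], [0., 2., -3., 0.], [2., 3., 4., 8.]])

    # Both operands must use the same sparse format; here COO with sparse_dim=2.
    coo_x = x.to_sparse_coo(2)
    coo_y = y.to_sparse_coo(2)

    # Exercise the four new element-wise ops and densify the results for printing.
    for op in (sparse.add, sparse.subtract, sparse.multiply, sparse.divide):
        out = op(coo_x, coo_y)
        print(op.__name__, out.to_dense())

The example sticks to float32 values because divide casts int32/int64 inputs to float32 through the _cast helper before dispatching to final_state_sparse_divide, so an integer division still returns a float32 sparse tensor.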