// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" namespace phi { namespace sparse { #define DECLARE_SPARSE_UNARY_KERNEL(prefix) \ template \ void prefix##CooKernel( \ const Context& dev_ctx, const SparseCooTensor& x, SparseCooTensor* out); \ \ template \ void prefix##CsrKernel( \ const Context& dev_ctx, const SparseCsrTensor& x, SparseCsrTensor* out); #define DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(prefix, attr) \ template \ void prefix##CooKernel(const Context& dev_ctx, \ const SparseCooTensor& x, \ float attr, \ SparseCooTensor* out); \ \ template \ void prefix##CsrKernel(const Context& dev_ctx, \ const SparseCsrTensor& x, \ float attr, \ SparseCsrTensor* out); DECLARE_SPARSE_UNARY_KERNEL(Sin) DECLARE_SPARSE_UNARY_KERNEL(Tan) DECLARE_SPARSE_UNARY_KERNEL(Asin) DECLARE_SPARSE_UNARY_KERNEL(Atan) DECLARE_SPARSE_UNARY_KERNEL(Sinh) DECLARE_SPARSE_UNARY_KERNEL(Asinh) DECLARE_SPARSE_UNARY_KERNEL(Atanh) DECLARE_SPARSE_UNARY_KERNEL(Relu) DECLARE_SPARSE_UNARY_KERNEL(Tanh) DECLARE_SPARSE_UNARY_KERNEL(Square) DECLARE_SPARSE_UNARY_KERNEL(Sqrt) DECLARE_SPARSE_UNARY_KERNEL(Log1p) DECLARE_SPARSE_UNARY_KERNEL(Abs) DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor) template void ScaleCooKernel(const Context& dev_ctx, const SparseCooTensor& x, float scale, float bias, bool bias_after_scale, SparseCooTensor* out); template void ScaleCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, float scale, float bias, bool bias_after_scale, SparseCsrTensor* out); template void DivCooScalarKernel(const Context& dev_ctx, const SparseCooTensor& x, float scalar, SparseCooTensor* out); template void DivCsrScalarKernel(const Context& dev_ctx, const SparseCsrTensor& x, float scalar, SparseCsrTensor* out); template void CastCooKernel(const Context& dev_ctx, const SparseCooTensor& x, DataType index_dtype, DataType value_dtype, SparseCooTensor* out); template void CastCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, DataType index_dtype, DataType value_dtype, SparseCsrTensor* out); template void TransposeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, const std::vector& perm, SparseCooTensor* out); template void TransposeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, const std::vector& perm, SparseCsrTensor* out); template SparseCooTensor TransposeCoo(const Context& dev_ctx, const SparseCooTensor& x, const std::vector& perm) { PADDLE_ENFORCE_EQ(x.sparse_dim(), perm.size(), phi::errors::InvalidArgument( "size of perm must be equal than the x.sparse_dim()")); SparseCooTensor coo; TransposeCooKernel(dev_ctx, x, perm, &coo); return coo; } template SparseCsrTensor TransposeCsr(const Context& dev_ctx, const SparseCsrTensor& x, const std::vector& perm) { PADDLE_ENFORCE_LE( 2, perm.size(), phi::errors::InvalidArgument("size of perm must be equal to 2 or 3")); PADDLE_ENFORCE_GE( 3, perm.size(), phi::errors::InvalidArgument("size of perm must be equal to 2 or 3")); SparseCsrTensor csr; TransposeCsrKernel(dev_ctx, x, perm, &csr); return csr; } template SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) { SparseCooTensor coo; ReluCooKernel(dev_ctx, x, &coo); return coo; } template SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) { SparseCooTensor csr; ReluCsrKernel(dev_ctx, x, &csr); return csr; } template void ReshapeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, const phi::IntArray& shape, SparseCooTensor* out); template void ReshapeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, const phi::IntArray& shape, SparseCsrTensor* out); template SparseCooTensor ReshapeCoo(const Context& dev_ctx, const SparseCooTensor& x, const phi::IntArray& shape) { SparseCooTensor coo; ReshapeCooKernel(dev_ctx, x, shape, &coo); return coo; } template SparseCsrTensor ReshapeCsr(const Context& dev_ctx, const SparseCsrTensor& x, const phi::IntArray& shape) { PADDLE_ENFORCE_LE( 2, shape.size(), phi::errors::InvalidArgument("size of shape must be equal to 2 or 3")); PADDLE_ENFORCE_GE( 3, shape.size(), phi::errors::InvalidArgument("size of shape must be equal to 2 or 3")); SparseCsrTensor csr; ReshapeCsrKernel(dev_ctx, x, shape, &csr); return csr; } } // namespace sparse } // namespace phi