// unary_kernel.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

20 21 22 23
namespace phi {
namespace sparse {

#define DECLARE_SPARSE_UNARY_KERNEL(prefix)                                    \
24
  template <typename T, typename Context>                                      \
25
  void prefix##CooKernel(                                                      \
26 27 28
      const Context& dev_ctx, const SparseCooTensor& x, SparseCooTensor* out); \
                                                                               \
  template <typename T, typename Context>                                      \
29
  void prefix##CsrKernel(                                                      \
30 31
      const Context& dev_ctx, const SparseCsrTensor& x, SparseCsrTensor* out);

32 33 34 35 36 37 38 39 40 41 42 43
#define DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(prefix, attr) \
  template <typename T, typename Context>                       \
  void prefix##CooKernel(const Context& dev_ctx,                \
                         const SparseCooTensor& x,              \
                         float attr,                            \
                         SparseCooTensor* out);                 \
                                                                \
  template <typename T, typename Context>                       \
  void prefix##CsrKernel(const Context& dev_ctx,                \
                         const SparseCsrTensor& x,              \
                         float attr,                            \
                         SparseCsrTensor* out);
44

45 46 47 48 49 50 51
DECLARE_SPARSE_UNARY_KERNEL(Sin)
DECLARE_SPARSE_UNARY_KERNEL(Tan)
DECLARE_SPARSE_UNARY_KERNEL(Asin)
DECLARE_SPARSE_UNARY_KERNEL(Atan)
DECLARE_SPARSE_UNARY_KERNEL(Sinh)
DECLARE_SPARSE_UNARY_KERNEL(Asinh)
DECLARE_SPARSE_UNARY_KERNEL(Atanh)
52
DECLARE_SPARSE_UNARY_KERNEL(Relu)
53 54
DECLARE_SPARSE_UNARY_KERNEL(Tanh)
DECLARE_SPARSE_UNARY_KERNEL(Square)
55
DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
DECLARE_SPARSE_UNARY_KERNEL(Log1p)
DECLARE_SPARSE_UNARY_KERNEL(Abs)
DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)

template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    float scale,
                    float bias,
                    bool bias_after_scale,
                    SparseCooTensor* out);

template <typename T, typename Context>
void ScaleCsrKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    float scale,
                    float bias,
                    bool bias_after_scale,
                    SparseCsrTensor* out);
75 76

template <typename T, typename Context>
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
void DivCooScalarKernel(const Context& dev_ctx,
                        const SparseCooTensor& x,
                        float scalar,
                        SparseCooTensor* out);

template <typename T, typename Context>
void DivCsrScalarKernel(const Context& dev_ctx,
                        const SparseCsrTensor& x,
                        float scalar,
                        SparseCsrTensor* out);

template <typename T, typename Context>
void CastCooKernel(const Context& dev_ctx,
                   const SparseCooTensor& x,
                   DataType index_dtype,
                   DataType value_dtype,
                   SparseCooTensor* out);

template <typename T, typename Context>
void CastCsrKernel(const Context& dev_ctx,
                   const SparseCsrTensor& x,
                   DataType index_dtype,
                   DataType value_dtype,
                   SparseCsrTensor* out);

template <typename T, typename Context>
SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) {
  SparseCooTensor coo;
  ReluCooKernel<T, Context>(dev_ctx, x, &coo);
106 107 108
  return coo;
}

109 110 111 112 113 114 115
template <typename T, typename Context>
SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
  SparseCooTensor csr;
  ReluCsrKernel<T, Context>(dev_ctx, x, &csr);
  return csr;
}

116 117
}  // namespace sparse
}  // namespace phi