unary_kernel.cu 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"

namespace phi {
namespace sparse {

// Elementwise functor that divides each input value by a fixed divisor.
//
// The divisor is converted to T once on the host by the caller and captured
// by value, so every device invocation divides by the same `value_`.
// No zero-divisor check is performed here.
template <typename T>
struct DivScalarFunctor {
  T value_;

  explicit DivScalarFunctor(T value) : value_(value) {}

  __device__ __forceinline__ T operator()(const T x) const {
    return x / value_;
  }
};

namespace {

// Writes `x_values / scalar` into `out_values`, one output element per input
// element, using funcs::ElementwiseKernel.
//
// Shared by the COO and CSR kernels below, which differ only in how the
// output tensor is initialised. `out_values` is expected to have been
// prepared by EmptyLike{Coo,Csr}Kernel (assumed to size/allocate it like
// `x_values` -- confirm against that helper).
template <typename T, typename Context>
void DivScalarValues(const Context& dev_ctx,
                     const DenseTensor& x_values,
                     float scalar,
                     DenseTensor* out_values) {
  // A sparse tensor with no stored elements has nothing to divide. Returning
  // early avoids relying on the elementwise launcher to cope with a
  // zero-sized launch configuration.
  if (x_values.numel() == 0) {
    return;
  }

  std::vector<const DenseTensor*> ins = {&x_values};
  std::vector<DenseTensor*> outs = {out_values};
  // The op exposes the divisor as float; convert it to the kernel's element
  // type once, on the host.
  DivScalarFunctor<T> func(static_cast<T>(scalar));
  funcs::ElementwiseKernel<T, DivScalarFunctor<T>>(dev_ctx, ins, &outs, func);
}

}  // namespace

// Divides every stored value of a COO sparse tensor by `scalar`.
//
// dev_ctx: device context used for allocation and the elementwise launch.
// x:       input COO tensor.
// scalar:  divisor, converted to T before dividing.
// out:     output COO tensor. It is initialised from `x` by
//          EmptyLikeCooKernel, then its stored values are set to
//          x's stored values / scalar.
template <typename T, typename Context>
void DivCooScalarKernel(const Context& dev_ctx,
                        const SparseCooTensor& x,
                        float scalar,
                        SparseCooTensor* out) {
  EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
  DivScalarValues<T, Context>(dev_ctx,
                              x.non_zero_elements(),
                              scalar,
                              out->mutable_non_zero_elements());
}

// Divides every stored value of a CSR sparse tensor by `scalar`.
//
// dev_ctx: device context used for allocation and the elementwise launch.
// x:       input CSR tensor.
// scalar:  divisor, converted to T before dividing.
// out:     output CSR tensor. It is initialised from `x` by
//          EmptyLikeCsrKernel, then its stored values are set to
//          x's stored values / scalar.
template <typename T, typename Context>
void DivCsrScalarKernel(const Context& dev_ctx,
                        const SparseCsrTensor& x,
                        float scalar,
                        SparseCsrTensor* out) {
  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
  DivScalarValues<T, Context>(dev_ctx,
                              x.non_zero_elements(),
                              scalar,
                              out->mutable_non_zero_elements());
}

}  // namespace sparse
}  // namespace phi

// Registers the GPU kernels for one sparse unary op in both sparse formats,
// for float and double:
//   <name>_coo -> phi::sparse::<prefix>CooKernel, input 0 tagged SPARSE_COO
//   <name>_csr -> phi::sparse::<prefix>CsrKernel, input 0 tagged SPARSE_CSR
// `name` and `prefix` are only ever used as operands of token pasting (##),
// so they are joined as written and never macro-expanded first.
#define PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(name, prefix)          \
  PD_REGISTER_KERNEL(name##_coo,                                   \
                     GPU,                                          \
                     ALL_LAYOUT,                                   \
                     phi::sparse::prefix##CooKernel,               \
                     float,                                        \
                     double) {                                     \
    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
  }                                                                \
                                                                   \
  PD_REGISTER_KERNEL(name##_csr,                                   \
                     GPU,                                          \
                     ALL_LAYOUT,                                   \
                     phi::sparse::prefix##CsrKernel,               \
                     float,                                        \
                     double) {                                     \
    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
  }

// Each line registers the `<op>_coo` and `<op>_csr` GPU kernels (float and
// double), backed by phi::sparse::<Op>CooKernel / <Op>CsrKernel. Those kernel
// templates are not defined in this file -- presumably they come from
// sparse/impl/unary_kernel_impl.h; confirm there.
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sin, Sin)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(tan, Tan)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(asin, Asin)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(atan, Atan)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sinh, Sinh)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(tanh, Tanh)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(asinh, Asinh)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(atanh, Atanh)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sqrt, Sqrt)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(square, Square)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)

// Scalar-division kernels (float and double). These are registered by hand
// rather than through PD_REGISTER_SPARSE_UNARY_GPU_KERNEL: the op names
// (`divide_coo_scalar`, `divide_csr_scalar`) and the kernel names
// (`DivCooScalarKernel`, `DivCsrScalarKernel`) place the format before the
// `_scalar` / `Scalar` suffix, which the macro's `name##_coo` and
// `prefix##CooKernel` pasting cannot produce.
PD_REGISTER_KERNEL(divide_coo_scalar,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DivCooScalarKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(divide_csr_scalar,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::DivCsrScalarKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

// Sparse cast kernels. Unlike the other kernels in this file, they are
// registered for integer and bool types in addition to float/double.
// Input 0 is tagged with its sparse layout, matching every other sparse
// kernel registered in this file.
PD_REGISTER_KERNEL(cast_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CastCooKernel,
                   float,
                   double,
                   int8_t,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(cast_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::CastCsrKernel,
                   float,
                   double,
                   int8_t,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}