/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/kernels/cuda/math.h"

#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
#include "paddle/pten/kernels/hybird/eigen/sign.h"
#include "paddle/pten/kernels/hybird/general/elementwise_functor.h"
#include "paddle/pten/kernels/hybird/general/reduce_impl.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/kernel_registry.h"

namespace pten {

/**
 * Util Functors
 */

// Unary functor that multiplies its argument by the reciprocal of a fixed
// integer count.
//
// The reciprocal 1.0 / n is evaluated in double precision and converted to T
// once, in the constructor; every call afterwards is a single multiplication
// in T. For integral T the conversion truncates the reciprocal.
template <typename T>
struct DivideFunctor {
  // n: the count whose reciprocal is applied to every input.
  HOSTDEVICE inline explicit DivideFunctor(int n)
      : reciprocal_(static_cast<T>(1.0 / static_cast<double>(n))) {}

  // Returns x multiplied by the stored reciprocal.
  HOSTDEVICE inline T operator()(const T& x) const {
    return x * reciprocal_;
  }

 private:
  // 1 / n, already converted to T.
  T reciprocal_;
};

/**
 * Kernels
 */

// CUDA entry point for the sign kernel.
//
// Delegates directly to eigen::Sign, instantiated for CUDAContext and the
// element type T.
//
//   dev_ctx  CUDA device context handed to eigen::Sign.
//   x        input tensor.
//   out      output tensor handed to eigen::Sign.
template <typename T>
void Sign(const CUDAContext& dev_ctx,
          const DenseTensor& x,
          DenseTensor* out) {
  eigen::Sign<CUDAContext, T>(dev_ctx, x, out);
}

template <typename T>
66 67 68 69 70 71 72 73 74 75
void Mean(const CUDAContext& dev_ctx,
          const DenseTensor& x,
          const std::vector<int64_t>& dims,
          bool keep_dim,
          bool reduce_all,
          DataType in_dtype,
          DataType out_dtype,
          DenseTensor* out) {
  pten::Reduce<T, paddle::operators::CustomMean>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
76 77
}

78
// Create the definition of Add
Y
YuanRisheng 已提交
79
DEFINE_CUDA_ELEMENTWISE_OP(Add)
80 81 82 83 84 85
// Create the definition of Subtract
DEFINE_CUDA_ELEMENTWISE_OP(Subtract)
// Create the definition of Multiply
DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
// Create the definition of Divide
DEFINE_CUDA_ELEMENTWISE_OP(Divide)
86

87 88 89 90 91 92 93 94 95 96 97 98 99
template <typename T>
void Sum(const CUDAContext& dev_ctx,
         const DenseTensor& x,
         const std::vector<int64_t>& dims,
         bool keep_dim,
         bool reduce_all,
         DataType in_dtype,
         DataType out_dtype,
         DenseTensor* out) {
  pten::Reduce<T, paddle::operators::CustomSum>(
      dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}

}  // namespace pten

// Short aliases for the half-precision and complex element types named in the
// kernel registrations that follow.
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;

// Register pten::Sign as the CUDA "sign" kernel for float, double and float16.
// The registration body is empty.
PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) {
}
// Register pten::Mean as the CUDA "mean" kernel for float, double and bool.
// The registration body is empty.
PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {}
109
PT_REGISTER_KERNEL(add,
110
                   CUDA,
111
                   ALL_LAYOUT,
112
                   pten::Add,
113 114 115 116 117 118 119
                   float,
                   double,
                   int,
                   int64_t,
                   float16,
                   complex64,
                   complex128) {}
120
PT_REGISTER_KERNEL(subtract,
121
                   CUDA,
122
                   ALL_LAYOUT,
123
                   pten::Subtract,
124 125 126 127 128 129 130
                   float,
                   double,
                   int,
                   int64_t,
                   float16,
                   complex64,
                   complex128) {}
131
PT_REGISTER_KERNEL(divide,
132
                   CUDA,
133
                   ALL_LAYOUT,
134
                   pten::Divide,
135 136 137 138 139 140 141
                   float,
                   double,
                   int,
                   int64_t,
                   float16,
                   complex64,
                   complex128) {}
142
PT_REGISTER_KERNEL(multiply,
Y
YuanRisheng 已提交
143
                   CUDA,
144
                   ALL_LAYOUT,
145
                   pten::Multiply,
Y
YuanRisheng 已提交
146 147 148 149 150 151 152 153
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   float16,
                   complex64,
                   complex128) {}
154
PT_REGISTER_KERNEL(sum,
155
                   CUDA,
156
                   ALL_LAYOUT,
157 158 159 160 161 162 163 164 165 166 167
                   pten::Sum,
                   bool,
                   float,
                   double,
                   float16,
                   int,
                   int64_t,
                   complex64,
                   complex128) {
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}