functor_primitives.h 6.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

S
sneaxiy 已提交
17 18 19
#include <cmath>
#include <limits>
#include <type_traits>

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
20
#include "paddle/pten/kernels/funcs/eigen/extensions.h"
21

22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {

// Device-only exp/log overloads used by the functors in this file.
// float and double forward to the single/double precision math functions;
// platform::float16 goes through Eigen's numext helpers.

static __device__ __forceinline__ float Exp(float val) { return expf(val); }
static __device__ __forceinline__ float Log(float val) { return logf(val); }

static __device__ __forceinline__ double Exp(double val) { return exp(val); }
static __device__ __forceinline__ double Log(double val) { return log(val); }

static __device__ __forceinline__ platform::float16 Exp(
    platform::float16 val) {
  return ::Eigen::numext::exp(val);
}

static __device__ __forceinline__ platform::float16 Log(
    platform::float16 val) {
  return ::Eigen::numext::log(val);
}

}  // namespace details

/******************************** Unary Functor *******************************/

/**
 * @brief Default unary exp functor: returns exp(x) converted to Ty.
 *
 * The exponential is evaluated by details::Exp, whose overloads
 * (platform::float16, float, double) are all __device__-only, so operator()
 * is effectively usable from device code only, even though it is declared
 * HOSTDEVICE.
 *
 * @tparam Tx Input type; must resolve to one of the details::Exp overloads.
 * @tparam Ty Output type (defaults to Tx).
 */
template <typename Tx, typename Ty = Tx>
struct ExpFunctor {
  HOSTDEVICE inline ExpFunctor() {}

  // The count argument is ignored here; presumably accepted so that every
  // unary functor in this file can be constructed the same way.
  HOSTDEVICE explicit inline ExpFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx x) const {
    return static_cast<Ty>(details::Exp(x));
  }
};

/**
 * @brief Default unary identity functor: returns x converted to Ty.
 *
 * @tparam Tx Input type.
 * @tparam Ty Output type (defaults to Tx); the conversion is a static_cast.
 */
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  HOSTDEVICE inline IdentityFunctor() {}

  // The count argument is ignored here; presumably accepted so that every
  // unary functor in this file can be constructed the same way.
  HOSTDEVICE explicit inline IdentityFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx x) const {
    return static_cast<Ty>(x);
  }
};

/**
 * @brief Default unary div functor: divides x by a constant n.
 *
 * The reciprocal 1/n is computed once at construction and each call
 * multiplies by it in MPType, then converts the product to Ty.
 * A default-constructed functor uses a reciprocal of 1 (identity).
 *
 * @tparam Tx Input type.
 * @tparam Ty Output type (defaults to Tx).
 */
template <typename Tx, typename Ty = Tx>
struct DivideFunctor {
 private:
  // Computation type chosen by MPTypeTrait; presumably a wider type such as
  // float when Tx is float16 -- see fp16_type_traits.h to confirm.
  using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type;

 public:
  HOSTDEVICE inline DivideFunctor() : n_inv(static_cast<MPType>(1.0f)) {}

  // The reciprocal is evaluated in double before narrowing to MPType.
  // n == 0 is not checked; 1.0 / 0 evaluates to +inf for IEEE doubles.
  HOSTDEVICE explicit inline DivideFunctor(int n)
      : n_inv(static_cast<MPType>(1.0 / n)) {}

  HOSTDEVICE inline Ty operator()(const Tx x) const {
    return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
  }

 private:
  MPType n_inv;  // 1 / n, stored in the computation type.
};

/**
 * @brief Default unary negation functor: returns -x converted to Ty.
 *
 * Despite the name, this computes the additive inverse (-x), not the
 * reciprocal (1/x).
 *
 * @tparam Tx Input type.
 * @tparam Ty Output type (defaults to Tx).
 */
template <typename Tx, typename Ty = Tx>
struct InverseFunctor {
  HOSTDEVICE inline InverseFunctor() {}

  // The count argument is ignored here; presumably accepted so that every
  // unary functor in this file can be constructed the same way.
  HOSTDEVICE explicit inline InverseFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx x) const {
    return static_cast<Ty>(-x);
  }
};

/**
 * @brief Default unary square functor: returns x * x.
 *
 * x is first converted to Ty and the multiplication is performed in Ty.
 *
 * @tparam Tx Input type.
 * @tparam Ty Output type (defaults to Tx).
 */
template <typename Tx, typename Ty = Tx>
struct SquareFunctor {
  HOSTDEVICE inline SquareFunctor() {}

  // The count argument is ignored here; presumably accepted so that every
  // unary functor in this file can be constructed the same way.
  HOSTDEVICE explicit inline SquareFunctor(int /*n*/) {}

  HOSTDEVICE inline Ty operator()(const Tx x) const {
    return static_cast<Ty>(x) * static_cast<Ty>(x);
  }
};

/****************************** Binary Functor ********************************/

/**
 * @brief Default binary min functor.
 *
 * initial() returns std::numeric_limits<T>::max(), the neutral start value
 * for min on ordinary values (min(max, x) == x).
 */
template <typename T>
struct MinFunctor {
  inline T initial() const {
    return static_cast<T>(std::numeric_limits<T>::max());
  }

  __device__ __forceinline__ T operator()(const T a, const T b) const {
    // (b < a) is false when the operands compare equal or either is NaN,
    // so a is returned in those cases.
    return (b < a) ? b : a;
  }
};

/**
 * @brief Default binary max functor.
 *
 * initial() returns std::numeric_limits<T>::lowest(), the neutral start
 * value for max on ordinary values (max(lowest, x) == x).
 */
template <typename T>
struct MaxFunctor {
  inline T initial() const {
    return static_cast<T>(std::numeric_limits<T>::lowest());
  }

  __device__ __forceinline__ T operator()(const T a, const T b) const {
    // (b > a) is false when the operands compare equal or either is NaN,
    // so a is returned in those cases.
    return (b > a) ? b : a;
  }
};

/**
 * @brief Default binary add functor: returns b + a.
 *
 * initial() returns 0, the neutral start value for addition.
 */
template <typename T>
struct AddFunctor {
  inline T initial() const { return static_cast<T>(0.0f); }

  __device__ __forceinline__ T operator()(const T a, const T b) const {
    return b + a;
  }
};

/**
 * @brief Default binary mul functor: returns b * a.
 *
 * initial() returns 1, the neutral start value for multiplication.
 */
template <typename T>
struct MulFunctor {
  inline T initial() const { return static_cast<T>(1.0f); }

  __device__ __forceinline__ T operator()(const T a, const T b) const {
    return b * a;
  }
};

/**
 * @brief Default binary logical or functor: returns (b || a) converted to T.
 *
 * initial() returns false, the neutral start value for logical or.
 */
template <typename T>
struct LogicalOrFunctor {
  inline T initial() const { return static_cast<T>(false); }

  __device__ __forceinline__ T operator()(const T a, const T b) const {
    return b || a;
  }
};

/**
 * @brief Default binary logical and functor: returns (b && a) converted to T.
 *
 * initial() returns true, the neutral start value for logical and.
 */
template <typename T>
struct LogicalAndFunctor {
  inline T initial() const { return static_cast<T>(true); }

  __device__ __forceinline__ T operator()(const T a, const T b) const {
    return b && a;
  }
};

/**
 * @brief Default binary sub functor: returns a - b.
 *
 * initial() returns 0. Unlike the reduction-style functors above,
 * operator() is HOSTDEVICE, so it can also be called from host code.
 */
template <typename T>
struct SubFunctor {
  inline T initial() const { return static_cast<T>(0.0f); }

  inline HOSTDEVICE T operator()(const T a, const T b) const { return a - b; }
};

/**
 * @brief Default binary div functor: returns a / b.
 *
 * This primary template serves non-integral T; integral T selects the
 * specialization that rejects a zero divisor. initial() returns 1.
 */
template <typename T, typename Enable = void>
struct DivFunctor {
  inline T initial() const { return static_cast<T>(1.0f); }

  inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; }
};

/**
 * @brief DivFunctor specialization for integral T: returns a / b after
 * enforcing b != 0. initial() returns 1.
 */
template <typename T>
struct DivFunctor<T,
                  typename std::enable_if<std::is_integral<T>::value>::type> {
  inline T initial() const { return static_cast<T>(1.0f); }

  inline HOSTDEVICE T operator()(const T a, const T b) const {
    // Integer division by zero is undefined behavior, so reject a zero
    // divisor before dividing.
    PADDLE_ENFORCE_NE(b, 0,
                      platform::errors::InvalidArgument(
                          "Integer division by zero encountered "
                          "in (floor) divide. Please check the input value."));
    return a / b;
  }
};

/**
 * @brief Default binary floor divide functor.
 *
 * Despite the name, the quotient is truncated toward zero, not floored.
 * The result differs from floor division when a and b have opposite signs
 * and the division is inexact. b == 0 is rejected for every T (the error
 * message mentions integer division even for floating-point T).
 * initial() returns 1.
 */
template <typename T>
struct FloorDivFunctor {
  inline T initial() const { return static_cast<T>(1.0f); }

  inline HOSTDEVICE T operator()(const T a, const T b) const {
    PADDLE_ENFORCE_NE(b, 0,
                      platform::errors::InvalidArgument(
                          "Integer division by zero encountered "
                          "in (floor) divide. Please check the input value."));
    return TruncDiv(a, b, std::is_integral<T>());
  }

 private:
  // Integral T: built-in integer division already truncates toward zero.
  // Returning it directly avoids std::trunc's double round trip, which loses
  // precision for 64-bit quotients whose magnitude exceeds 2^53.
  static inline HOSTDEVICE T TruncDiv(const T a, const T b, std::true_type) {
    return a / b;
  }

  // Non-integral T: truncate the floating-point quotient toward zero.
  static inline HOSTDEVICE T TruncDiv(const T a, const T b, std::false_type) {
    return static_cast<T>(std::trunc(a / b));
  }
};

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle