// elementwise_functor.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <type_traits>

#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace paddle {
namespace operators {

// Define the binary functors used in elementwise ops.
// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.

// Add
template <typename T>
29
using AddFunctor = phi::funcs::AddFunctor<T>;
30

31
template <typename T>
32
using InverseAddFunctor = phi::funcs::InverseAddFunctor<T>;
33 34 35

// Subtract
template <typename T>
36
using SubFunctor = phi::funcs::SubtractFunctor<T>;
37

38
template <typename T>
39
using InverseSubFunctor = phi::funcs::InverseSubtractFunctor<T>;
40 41 42

// Multiply
template <typename T>
43
using MulFunctor = phi::funcs::MultiplyFunctor<T>;
44

45
template <typename T>
46
using InverseMulFunctor = phi::funcs::InverseMultiplyFunctor<T>;
47 48 49

// Divide
template <typename T>
50
using DivFunctor = phi::funcs::DivideFunctor<T>;
51

52
template <typename T>
53
using InverseDivFunctor = phi::funcs::InverseDivideFunctor<T>;
54 55 56 57

// Floor Divide
template <typename T>
struct FloorDivFunctor {
58
  inline HOSTDEVICE T operator()(const T a, const T b) const {
59 60 61 62 63 64 65
    PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
    return static_cast<T>(std::trunc(a / b));
  }
};

template <typename T>
struct InverseFloorDivFunctor {
66
  inline HOSTDEVICE T operator()(const T a, const T b) const {
67 68 69 70 71 72 73 74 75 76
    PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO);
    return static_cast<T>(std::trunc(b / a));
  }
};

#undef DIV_ERROR_INFO

// Maximum
template <typename T>
struct MaxFunctor {
77
  inline HOSTDEVICE T operator()(const T a, const T b) const {
78 79 80 81 82 83 84
    return a > b ? a : b;
  }
};

// Minmum
template <typename T>
struct MinFunctor {
85
  inline HOSTDEVICE T operator()(const T a, const T b) const {
86 87 88 89
    return a < b ? a : b;
  }
};

90 91 92
// Shorthand for paddle's complex number type, used by the complex
// specializations of the multiply-gradient functors in this file.
template <typename T>
using Complex = paddle::platform::complex<T>;

// Fmax
template <typename T>
struct FMaxFunctor {
96
  inline HOSTDEVICE T operator()(const T a, const T b) const {
L
LJQ❤️ 已提交
97 98 99 100 101 102 103
    return std::fmax(a, b);
  }
};

// float16 specialization: evaluates fmax in float precision, then narrows the
// result back to float16.
template <>
struct FMaxFunctor<paddle::platform::float16> {
  inline HOSTDEVICE paddle::platform::float16 operator()(
      const paddle::platform::float16 a,
      const paddle::platform::float16 b) const {
    const float result =
        std::fmax(static_cast<float>(a), static_cast<float>(b));
    return static_cast<paddle::platform::float16>(result);
  }
};

template <>
struct FMaxFunctor<int> {
115
  inline HOSTDEVICE int operator()(const int a, const int b) const {
116 117 118 119 120 121 122 123 124
    float float_a = static_cast<float>(a);
    float float_b = static_cast<float>(b);
    auto result = std::fmax(float_a, float_b);
    return std::lrint(result);
  }
};

template <>
struct FMaxFunctor<int64_t> {
125
  inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
126 127 128 129 130 131 132
    double double_a = static_cast<double>(a);
    double double_b = static_cast<double>(b);
    auto result = std::fmax(double_a, double_b);
    return std::llrint(result);
  }
};

L
LJQ❤️ 已提交
133 134 135
// Fmin
template <typename T>
struct FMinFunctor {
136
  inline HOSTDEVICE T operator()(const T a, const T b) const {
L
LJQ❤️ 已提交
137 138 139 140 141 142 143
    return std::fmin(a, b);
  }
};

// float16 specialization: evaluates fmin in float precision, then narrows the
// result back to float16.
template <>
struct FMinFunctor<paddle::platform::float16> {
  inline HOSTDEVICE paddle::platform::float16 operator()(
      const paddle::platform::float16 a,
      const paddle::platform::float16 b) const {
    const float result =
        std::fmin(static_cast<float>(a), static_cast<float>(b));
    return static_cast<paddle::platform::float16>(result);
  }
};

template <>
struct FMinFunctor<int> {
155
  inline HOSTDEVICE int operator()(const int a, const int b) const {
156 157 158 159 160 161 162 163 164
    float float_a = static_cast<float>(a);
    float float_b = static_cast<float>(b);
    auto result = std::fmin(float_a, float_b);
    return std::lrint(result);
  }
};

template <>
struct FMinFunctor<int64_t> {
165
  inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
166 167 168 169 170 171 172
    double double_a = static_cast<double>(a);
    double double_b = static_cast<double>(b);
    auto result = std::fmin(double_a, double_b);
    return std::llrint(result);
  }
};

173 174
template <typename T>
struct MinGradXFunctor {
175
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
176 177 178 179 180
    return dout * static_cast<T>(x < y);
  }
};
template <typename T>
struct MinGradYFunctor {
181
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
182 183 184 185 186 187
    return dout * static_cast<T>(x >= y);
  }
};

template <typename InT, typename OutT>
struct MinGradXYFunctor {
188 189 190
  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
                                                   const InT dout) {
    phi::Array<OutT, 2> outs;
191 192 193 194 195 196 197 198
    // dx = dout * (x < y)
    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x < y));
    // dy = dout * (x >= y)
    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x >= y));
    return outs;
  }
};

199 200
template <typename T>
struct MulGradFunctor {
201
  inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
202 203 204
};
// Complex specialization of the multiply-gradient helper: returns a * conj(b).
template <typename T>
struct MulGradFunctor<Complex<T>> {
  inline HOSTDEVICE Complex<T> operator()(const Complex<T> a,
                                          const Complex<T> b) const {
    Complex<T> b_conj(b.real, -b.imag);
    return a * b_conj;
  }
};

template <typename InT, typename OutT>
struct MulGradXYFunctor {
214 215 216
  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT a, const InT b,
                                                   const InT c) {
    phi::Array<OutT, 2> outs;
217 218 219 220 221 222 223 224 225 226
    // dx = dout * y
    outs[0] = a * b;
    // dy = dout * x
    outs[1] = a * c;
    return outs;
  }
};

// Complex specialization of the fused multiply gradient:
// {dx, dy} = {dout * conj(y), dout * conj(x)}, arguments ordered (dout, y, x).
template <typename InT, typename OutT>
struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
  inline HOSTDEVICE phi::Array<Complex<OutT>, 2> operator()(
      const Complex<InT> dout, const Complex<InT> y,
      const Complex<InT> x) const {
    phi::Array<Complex<OutT>, 2> outs;
    // dx = dout * conj(y)
    Complex<InT> y_conj(y.real, -y.imag);
    outs[0] = dout * y_conj;
    // dy = dout * conj(x)
    Complex<InT> x_conj(x.real, -x.imag);
    outs[1] = dout * x_conj;
    return outs;
  }
};

// Ternary compare
template <typename T>
struct MaxGradXFunctor {
243
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
244 245 246 247 248
    return dout * static_cast<T>(x > y);
  }
};
template <typename T>
struct MaxGradYFunctor {
249
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
250 251 252 253 254 255
    return dout * static_cast<T>(x <= y);
  }
};

template <typename InT, typename OutT>
struct MaxGradXYFunctor {
256 257 258
  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
                                                   const InT dout) {
    phi::Array<OutT, 2> outs;
259 260 261 262 263 264 265 266
    // dx = dout * (x > y)
    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
    // dy = dout * (x <= y)
    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
    return outs;
  }
};

267 268
}  // namespace operators
}  // namespace paddle