elementwise_functor.h 4.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <type_traits>

#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace paddle {
namespace operators {

// Define the binary functors used in elementwise ops.
// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.

// Add
template <typename T>
26
using AddFunctor = phi::funcs::AddFunctor<T>;
27

28
template <typename T>
29
using InverseAddFunctor = phi::funcs::InverseAddFunctor<T>;
30 31 32

// Subtract
template <typename T>
33
using SubFunctor = phi::funcs::SubtractFunctor<T>;
34

35
template <typename T>
36
using InverseSubFunctor = phi::funcs::InverseSubtractFunctor<T>;
37 38 39

// Multiply
template <typename T>
40
using MulFunctor = phi::funcs::MultiplyFunctor<T>;
41

42
template <typename T>
43
using InverseMulFunctor = phi::funcs::InverseMultiplyFunctor<T>;
44 45 46

// Divide
template <typename T>
47
using DivFunctor = phi::funcs::DivideFunctor<T>;
48

49
template <typename T>
50
using InverseDivFunctor = phi::funcs::InverseDivideFunctor<T>;
51 52 53 54

// Floor Divide
template <typename T>
struct FloorDivFunctor {
55
  inline HOSTDEVICE T operator()(const T a, const T b) const {
56 57 58 59 60 61 62
    PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
    return static_cast<T>(std::trunc(a / b));
  }
};

template <typename T>
struct InverseFloorDivFunctor {
63
  inline HOSTDEVICE T operator()(const T a, const T b) const {
64 65 66 67 68 69 70 71 72 73
    PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO);
    return static_cast<T>(std::trunc(b / a));
  }
};

// DIV_ERROR_INFO is not defined in this file; presumably it comes from
// "paddle/phi/kernels/funcs/elementwise_functor.h" (TODO confirm). It is
// undefined right after its last use by the floor-divide functors.
// NOTE(review): if the defining header is include-guarded / #pragma once,
// anything parsed later in the same translation unit that still expects
// DIV_ERROR_INFO will no longer see it — confirm no such users exist.
#undef DIV_ERROR_INFO

// Maximum
template <typename T>
struct MaxFunctor {
74
  inline HOSTDEVICE T operator()(const T a, const T b) const {
75 76 77 78 79 80 81
    return a > b ? a : b;
  }
};

// Minmum
template <typename T>
struct MinFunctor {
82
  inline HOSTDEVICE T operator()(const T a, const T b) const {
83 84 85 86
    return a < b ? a : b;
  }
};

87 88 89
template <typename T>
using Complex = paddle::platform::complex<T>;

90 91
template <typename T>
struct MinGradXFunctor {
92
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
93 94 95 96 97
    return dout * static_cast<T>(x < y);
  }
};
template <typename T>
struct MinGradYFunctor {
98
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
99 100 101 102 103 104
    return dout * static_cast<T>(x >= y);
  }
};

template <typename InT, typename OutT>
struct MinGradXYFunctor {
105 106 107
  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
                                                   const InT dout) {
    phi::Array<OutT, 2> outs;
108 109 110 111 112 113 114 115
    // dx = dout * (x < y)
    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x < y));
    // dy = dout * (x >= y)
    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x >= y));
    return outs;
  }
};

116 117 118
// Ternary compare
template <typename T>
struct MaxGradXFunctor {
119
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
120 121 122 123 124
    return dout * static_cast<T>(x > y);
  }
};
template <typename T>
struct MaxGradYFunctor {
125
  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
126 127 128 129 130 131
    return dout * static_cast<T>(x <= y);
  }
};

template <typename InT, typename OutT>
struct MaxGradXYFunctor {
132 133 134
  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
                                                   const InT dout) {
    phi::Array<OutT, 2> outs;
135 136 137 138 139 140 141 142
    // dx = dout * (x > y)
    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
    // dy = dout * (x <= y)
    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
    return outs;
  }
};

143 144
}  // namespace operators
}  // namespace paddle