/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace operators {
namespace math {

// Z = BinaryFunctor(X, UnaryFunctor(Y))
//
// Forward functor for a compound elementwise op: UnaryFunctor is applied to
// Y first, and its result is combined with X through BinaryFunctor. The
// intermediate value UnaryFunctor(Y) can be produced on its own
// (GetIntermediateOut) and fed back later (GetOutUseIntermediateOut), so it
// does not have to be recomputed.
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
  // func1: the outer binary functor; func2: the inner unary functor.
  BinaryCompoundFunctor(const BinaryFunctor &func1, const UnaryFunctor &func2)
      : func1_(func1), func2_(func2) {}

  // Returns BinaryFunctor(x, UnaryFunctor(y)).
  inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }

  // Returns BinaryFunctor(x, intermediate_out), where intermediate_out is
  // expected to be a previously computed UnaryFunctor(y).
  inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediate_out) {
    return func1_(x, intermediate_out);
  }

  // Returns the intermediate value UnaryFunctor(y). `x` is ignored; the
  // (x, y) signature matches UnaryCompoundFunctor::GetIntermediateOut,
  // presumably so both compound functors share one interface.
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(y); }

  BinaryFunctor func1_;
  UnaryFunctor func2_;
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
//
// Forward functor for a compound elementwise op: BinaryFunctor combines X
// and Y first, and UnaryFunctor is applied to the result. The intermediate
// value BinaryFunctor(X, Y) can be produced on its own (GetIntermediateOut)
// and fed back later (GetOutUseIntermediateOut), so it does not have to be
// recomputed.
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
  // func1: the outer unary functor; func2: the inner binary functor.
  UnaryCompoundFunctor(const UnaryFunctor &func1, const BinaryFunctor &func2)
      : func1_(func1), func2_(func2) {}

  // Returns UnaryFunctor(BinaryFunctor(x, y)).
  inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }

  // Returns UnaryFunctor(intermediate_out), where intermediate_out is
  // expected to be a previously computed BinaryFunctor(x, y). `x` is ignored.
  inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediate_out) {
    return func1_(intermediate_out);
  }

  // Returns the intermediate value BinaryFunctor(x, y).
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(x, y); }

  UnaryFunctor func1_;
  BinaryFunctor func2_;
};

// Z = BinaryFunctor(X, UnaryFunctor(Y))
//
// Gradient functor for X of the compound op above:
//   dX = dout * DBinaryFun::Dx(x, UnaryFun(y)).
// DBinaryFun must provide Dx(x, y), expected to return the partial
// derivative of the binary functor with respect to its first argument.
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
  BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
                              const UnaryFun &unary_fun)
      : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}

  // Recomputes the intermediate value UnaryFun(y) before differentiating.
  // `out` is unused; it is part of the Recompute(x, y, out, dout) signature
  // shared by the gradient functors in this file.
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    return dout * d_binary_fun_.Dx(x, unary_fun_(y));
  }

  // Same gradient, but uses a saved intermediate_out (expected to equal
  // UnaryFun(y)) instead of recomputing it. `y` and `out` are unused.
  inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
                                         T dout) {
    return dout * d_binary_fun_.Dx(x, intermediate_out);
  }

 private:
  DBinaryFun d_binary_fun_;
  UnaryFun unary_fun_;
};

// Z = BinaryFunctor(X, UnaryFunctor(Y))
//
// Gradient functor for Y of the compound op above (chain rule through the
// unary functor):
//   dY = dout * DBinaryFun::Dy(x, UnaryFun(y)) * (derivative of UnaryFun).
// DBinaryFun must provide Dy(x, y). DUnaryFun must provide UseX(x),
// UseOut(out) and UseXAndOut(x, out): the unary derivative evaluated from
// its input, its output, or both.
// When InPlace is true, UseIntermediateOut evaluates the unary derivative
// from the intermediate output alone (UseOut) and never reads `y`. This is
// presumably because `y` may have been overwritten in place; confirm with
// callers.
template <typename T, typename DBinaryFun, typename UnaryFun,
          typename DUnaryFun, bool InPlace>
struct BinaryCompoundGradDyFunctor {
  BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
                              const UnaryFun &unary_fun,
                              const DUnaryFun &d_unary_fun)
      : d_binary_fun_(d_binary_fun),
        unary_fun_(unary_fun),
        d_unary_fun_(d_unary_fun) {}

  // Recomputes UnaryFun(y) and evaluates the unary derivative from `y`
  // (UseX), regardless of InPlace. `out` is unused.
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
  }

  // Same gradient, using a saved intermediate_out (expected to equal
  // UnaryFun(y)). `out` is unused.
  inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
                                         T dout) {
    if (InPlace) {
      return dout * d_binary_fun_.Dy(x, intermediate_out) *
             d_unary_fun_.UseOut(intermediate_out);
    } else {
      return dout * d_binary_fun_.Dy(x, intermediate_out) *
             d_unary_fun_.UseXAndOut(y, intermediate_out);
    }
  }

 private:
  DBinaryFun d_binary_fun_;
  UnaryFun unary_fun_;
  DUnaryFun d_unary_fun_;
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
//
// Gradient functor for X of the compound op above (chain rule):
//   dX = dout * (derivative of the unary functor) * DBinaryFun::Dx(x, y).
// DUnaryFun must provide UseOut(out) and UseXAndOut(x, out). DBinaryFun must
// provide Dx(x, y).
// When InPlace is true, the unary derivative is evaluated from `out` alone
// (UseOut), so the intermediate BinaryFun(x, y) is neither computed nor
// read. This is presumably because it may have been overwritten in place;
// confirm with callers. Otherwise UseXAndOut receives the intermediate
// value and `out`.
template <typename T, typename DUnaryFun, typename BinaryFun,
          typename DBinaryFun, bool InPlace>
struct UnaryCompoundGradDxFunctor {
  UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
                             const BinaryFun &binary_fun,
                             const DBinaryFun &d_binary_fun)
      : d_unary_fun_(d_unary_fun),
        binary_fun_(binary_fun),
        d_binary_fun_(d_binary_fun) {}

  // Recomputes the intermediate BinaryFun(x, y) when the unary derivative
  // needs it (InPlace == false).
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    T base;  // dout times the unary functor's derivative.
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
    } else {
      base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
    }
    return base * d_binary_fun_.Dx(x, y);
  }

  // Same gradient, using a saved intermediate_out (expected to equal
  // BinaryFun(x, y)) instead of recomputing it.
  inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
                                         T dout) {
    T base;  // dout times the unary functor's derivative.
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
    } else {
      base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
    }
    return base * d_binary_fun_.Dx(x, y);
  }

 private:
  DUnaryFun d_unary_fun_;
  BinaryFun binary_fun_;
  DBinaryFun d_binary_fun_;
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
//
// Gradient functor for Y of the compound op above (chain rule):
//   dY = dout * (derivative of the unary functor) * DBinaryFun::Dy(x, y).
// DUnaryFun must provide UseOut(out) and UseXAndOut(x, out). DBinaryFun must
// provide Dy(x, y).
// When InPlace is true, the unary derivative is evaluated from `out` alone
// (UseOut), so the intermediate BinaryFun(x, y) is neither computed nor
// read. This is presumably because it may have been overwritten in place;
// confirm with callers. Otherwise UseXAndOut receives the intermediate
// value and `out`.
template <typename T, typename DUnaryFun, typename BinaryFun,
          typename DBinaryFun, bool InPlace>
struct UnaryCompoundGradDyFunctor {
  UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
                             const BinaryFun &binary_fun,
                             const DBinaryFun &d_binary_fun)
      : d_unary_fun_(d_unary_fun),
        binary_fun_(binary_fun),
        d_binary_fun_(d_binary_fun) {}

  // Recomputes the intermediate BinaryFun(x, y) when the unary derivative
  // needs it (InPlace == false).
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    T base;  // dout times the unary functor's derivative.
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
    } else {
      base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
    }
    return base * d_binary_fun_.Dy(x, y);
  }

  // Same gradient, using a saved intermediate_out (expected to equal
  // BinaryFun(x, y)) instead of recomputing it.
  inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
                                         T dout) {
    T base;  // dout times the unary functor's derivative.
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
    } else {
      base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
    }
    return base * d_binary_fun_.Dy(x, y);
  }

 private:
  DUnaryFun d_unary_fun_;
  BinaryFun binary_fun_;
  DBinaryFun d_binary_fun_;
};

// Z = BinaryFunctor(X, UnaryFunctor(Y))
//
// Gradient functor for the intermediate value UnaryFun(Y) of the compound
// op above:
//   dIntermediate = dout * DBinaryFun::Dy(x, UnaryFun(y)).
// DBinaryFun must provide Dy(x, y).
// NOTE: "Intermedaite" in the type name is a misspelling. It is kept as-is
// because the name is part of the public interface.
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDIntermedaiteOutFunctor {
  BinaryCompoundGradDIntermedaiteOutFunctor(const DBinaryFun &d_binary_fun,
                                            const UnaryFun &unary_fun)
      : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}

  // Recomputes UnaryFun(y) before differentiating. `out` is unused.
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    return dout * d_binary_fun_.Dy(x, unary_fun_(y));
  }

  // Same gradient, using a saved intermediate_out (expected to equal
  // UnaryFun(y)). Unlike the dX/dY functors, this overload takes no `y`
  // parameter. `out` is unused.
  inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
                                         T dout) {
    return dout * d_binary_fun_.Dy(x, intermediate_out);
  }

 private:
  DBinaryFun d_binary_fun_;
  UnaryFun unary_fun_;
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
//
// Gradient functor for the intermediate value BinaryFun(X, Y) of the
// compound op above: dout times the unary functor's derivative.
// DUnaryFun must provide UseOut(out) and UseXAndOut(x, out).
// When InPlace is true, the derivative is evaluated from `out` alone
// (UseOut), so the intermediate value is neither computed nor read. This is
// presumably because it may have been overwritten in place; confirm with
// callers.
template <typename T, typename DUnaryFun, typename BinaryFun, bool InPlace>
struct UnaryCompoundGradDIntermediateFunctor {
  UnaryCompoundGradDIntermediateFunctor(const DUnaryFun &d_unary_fun,
                                        const BinaryFun &binary_fun)
      : d_unary_fun_(d_unary_fun), binary_fun_(binary_fun) {}

  // Recomputes BinaryFun(x, y) when the derivative needs it
  // (InPlace == false).
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    if (InPlace) {
      return dout * d_unary_fun_.UseOut(out);
    } else {
      return dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
    }
  }

  // Same gradient, using a saved intermediate_out (expected to equal
  // BinaryFun(x, y)). This overload takes no `y` parameter, and `x` is
  // unused.
  inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
                                         T dout) {
    if (InPlace) {
      return dout * d_unary_fun_.UseOut(out);
    } else {
      return dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
    }
  }

 private:
  DUnaryFun d_unary_fun_;
  BinaryFun binary_fun_;
};

}  // namespace math
}  // namespace operators
}  // namespace paddle