/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_set>
#include <vector>

namespace pten {
namespace funcs {
23

C
chengduo 已提交
24
// Z = BinaryFunctor(X, UnaryFunctor(Y))
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
  BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2)
      : func1_(func1), func2_(func2) {}

  inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }

  inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) {
    return func1_(x, intermediat_out);
  }

  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(y); }

  BinaryFunctor func1_;
  UnaryFunctor func2_;
};

C
chengduo 已提交
42
// Z = UnaryFunctor(BinaryFunctor(X, Y))
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
  UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2)
      : func1_(func1), func2_(func2) {}

  inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }

  inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) {
    return func1_(intermediat_out);
  }

  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(x, y); }

  UnaryFunctor func1_;
  BinaryFunctor func2_;
};

C
chengduo 已提交
60
// Z = BinaryFunctor(X, UnaryFunctor(Y))
61 62 63 64 65 66
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
  BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
                              const UnaryFun &unary_fun)
      : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}

C
chengduo 已提交
67
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
68 69 70
    return dout * d_binary_fun_.Dx(x, unary_fun_(y));
  }

71 72
  inline HOSTDEVICE T
  UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
73 74 75
    return dout * d_binary_fun_.Dx(x, intermediate_out);
  }

76 77
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }

78 79 80 81 82
 private:
  DBinaryFun d_binary_fun_;
  UnaryFun unary_fun_;
};

C
chengduo 已提交
83
// Z = BinaryFunctor(X, UnaryFunctor(Y))
84 85 86 87 88
template <typename T,
          typename DBinaryFun,
          typename UnaryFun,
          typename DUnaryFun,
          bool InPlace>
89 90 91 92 93 94 95 96
struct BinaryCompoundGradDyFunctor {
  BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
                              const UnaryFun &unary_fun,
                              const DUnaryFun &d_unary_fun)
      : d_binary_fun_(d_binary_fun),
        unary_fun_(unary_fun),
        d_unary_fun_(d_unary_fun) {}

C
chengduo 已提交
97 98
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
99 100
  }

101 102
  inline HOSTDEVICE T
  UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
C
chengduo 已提交
103 104 105 106 107 108 109
    if (InPlace) {
      return dout * d_binary_fun_.Dy(x, intermediate_out) *
             d_unary_fun_.UseOut(intermediate_out);
    } else {
      return dout * d_binary_fun_.Dy(x, intermediate_out) *
             d_unary_fun_.UseXAndOut(y, intermediate_out);
    }
110 111
  }

112 113
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }

114 115 116 117 118 119
 private:
  DBinaryFun d_binary_fun_;
  UnaryFun unary_fun_;
  DUnaryFun d_unary_fun_;
};

C
chengduo 已提交
120
// Z = UnaryFunctor(BinaryFunctor(X, Y))
121 122 123 124 125
template <typename T,
          typename DUnaryFun,
          typename BinaryFun,
          typename DBinaryFun,
          bool InPlace>
126 127 128 129 130 131 132 133
struct UnaryCompoundGradDxFunctor {
  UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
                             const BinaryFun &binary_fun,
                             const DBinaryFun &d_binary_fun)
      : d_unary_fun_(d_unary_fun),
        binary_fun_(binary_fun),
        d_binary_fun_(d_binary_fun) {}

C
chengduo 已提交
134
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
135
    T base;
C
chengduo 已提交
136 137
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
138
    } else {
C
chengduo 已提交
139
      base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
140 141 142 143
    }
    return base * d_binary_fun_.Dx(x, y);
  }

144 145
  inline HOSTDEVICE T
  UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
146
    T base;
C
chengduo 已提交
147 148
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
149
    } else {
C
chengduo 已提交
150
      base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
151 152 153 154
    }
    return base * d_binary_fun_.Dx(x, y);
  }

155 156
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }

157 158 159 160 161 162
 private:
  DUnaryFun d_unary_fun_;
  BinaryFun binary_fun_;
  DBinaryFun d_binary_fun_;
};

C
chengduo 已提交
163
// Z = UnaryFunctor(BinaryFunctor(X, Y))
164 165 166 167 168
template <typename T,
          typename DUnaryFun,
          typename BinaryFun,
          typename DBinaryFun,
          bool InPlace>
169 170 171 172 173 174 175 176
struct UnaryCompoundGradDyFunctor {
  UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
                             const BinaryFun &binary_fun,
                             const DBinaryFun &d_binary_fun)
      : d_unary_fun_(d_unary_fun),
        binary_fun_(binary_fun),
        d_binary_fun_(d_binary_fun) {}

C
chengduo 已提交
177
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
178
    T base;
C
chengduo 已提交
179 180
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
181
    } else {
C
chengduo 已提交
182
      base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
183 184 185 186
    }
    return base * d_binary_fun_.Dy(x, y);
  }

187 188
  inline HOSTDEVICE T
  UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
189
    T base;
C
chengduo 已提交
190 191
    if (InPlace) {
      base = dout * d_unary_fun_.UseOut(out);
192
    } else {
C
chengduo 已提交
193
      base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
194 195 196 197
    }
    return base * d_binary_fun_.Dy(x, y);
  }

198 199
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }

200 201 202 203 204 205
 private:
  DUnaryFun d_unary_fun_;
  BinaryFun binary_fun_;
  DBinaryFun d_binary_fun_;
};

C
chengduo 已提交
206 207 208 209 210 211 212 213 214 215 216
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDIntermedaiteOutFunctor {
  BinaryCompoundGradDIntermedaiteOutFunctor(const DBinaryFun &d_binary_fun,
                                            const UnaryFun &unary_fun)
      : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}

  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    return dout * d_binary_fun_.Dy(x, unary_fun_(y));
  }

217 218 219
  inline HOSTDEVICE T UseIntermediateOut(T x,
                                         T intermediate_out,
                                         T out,
C
chengduo 已提交
220 221 222 223
                                         T dout) {
    return dout * d_binary_fun_.Dy(x, intermediate_out);
  }

224 225
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }

C
chengduo 已提交
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
 private:
  DBinaryFun d_binary_fun_;
  UnaryFun unary_fun_;
};

// Z = UnaryFunctor(BinaryFunctor(X, Y))
//
// Gradient term for the intermediate value BinaryFunctor(X, Y):
//   dout * <unary derivative>.
// When InPlace is true, the unary derivative is evaluated from the final
// output `out` alone (UseOut). Otherwise it is evaluated from the unary input
// BinaryFunctor(x, y) together with `out` (UseXAndOut).
// UseIntermediateOut() takes (x, intermediate_out, out, dout), with no `y`.
//
// Requirements (from the calls below): DUnaryFun provides UseOut(T) and
// UseXAndOut(T, T), and BinaryFun is callable as T(T, T).
template <typename T, typename DUnaryFun, typename BinaryFun, bool InPlace>
struct UnaryCompoundGradDIntermediateFunctor {
  UnaryCompoundGradDIntermediateFunctor(const DUnaryFun &d_unary_fun,
                                        const BinaryFun &binary_fun)
      : d_unary_fun_(d_unary_fun), binary_fun_(binary_fun) {}

  // Recomputes BinaryFunctor(x, y) only when !InPlace.
  inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) {
    if (InPlace) {
      return dout * d_unary_fun_.UseOut(out);
    } else {
      return dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out);
    }
  }

  // intermediate_out is expected to equal GetIntermediateOut(x, y), i.e.
  // BinaryFunctor(x, y); it is read only when !InPlace. `x` is unused.
  inline HOSTDEVICE T UseIntermediateOut(T /*x*/,
                                         T intermediate_out,
                                         T out,
                                         T dout) {
    if (InPlace) {
      return dout * d_unary_fun_.UseOut(out);
    } else {
      return dout * d_unary_fun_.UseXAndOut(intermediate_out, out);
    }
  }

  // Returns BinaryFunctor(x, y).
  inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }

 private:
  DUnaryFun d_unary_fun_;
  BinaryFun binary_fun_;
};

}  // namespace funcs
}  // namespace pten