activation_op.h 42.1 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
L
Luo Tao 已提交
2 3 4 5 6 7 8 9 10
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
11 12

#pragma once
D
dzhwinter 已提交
13
// _USE_MATH_DEFINES only takes effect if it is defined before the first
// inclusion of <cmath>/<math.h> (MSVC needs it for M_SQRT1_2, M_2_SQRTPI),
// so it must precede every include below.
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <glog/logging.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

Q
qijun 已提交
36 37 38
namespace paddle {
namespace operators {

Z
zhoukunsheng 已提交
39

40 41 42 43 44 45 46 47 48 49 50 51 52
// Bit flags describing which forward-pass tensors a backward activation
// kernel needs. kDepX and kDepOut are tested independently with a bitwise
// AND (see ExtractActivationGradTensor).
enum ActBwdOpFwdDeps {
  kNoDeps = 0,       // needs neither the forward input X nor the output Out
  kDepX = 1 << 0,    // needs only the forward input X
  kDepOut = 1 << 1,  // needs only the forward output Out

  // Do not add new uses of kDepXOut: Out can always be recomputed from X in
  // the backward pass.
  // FIXME(zjl): the MKLDNN abs kernel needs both X and Out, which is why
  // this value exists at all. Developers should not rely on it!
  kDepXOut = kDepX | kDepOut
};

// Names of activation ops that may run in place. Declared here; the
// definition is not part of this header.
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet();

// Returns true if activation `op` may run in place, according to
// GetInplaceOpSet().
//
// A gradient op named "<fwd>_grad" is answered by looking up its forward op
// "<fwd>". For example, "relu_grad" is in-place iff "relu" is in the set.
// Any other name is looked up as-is.
static bool IsInplace(const std::string& op) {
  static auto InplaceOpSet = GetInplaceOpSet();
  static const std::string kGradSuffix = "_grad";

  // std::string::compare returns 0 on a match, so the suffix test must be
  // `== 0`. The suffix starts at size() - suffix length.
  const bool is_grad_op =
      op.size() > kGradSuffix.size() &&
      op.compare(op.size() - kGradSuffix.size(), kGradSuffix.size(),
                 kGradSuffix) == 0;
  if (is_grad_op) {
    return InplaceOpSet->count(
               op.substr(0, op.size() - kGradSuffix.size())) > 0;
  }
  return InplaceOpSet->count(op) > 0;
}

C
chengduo 已提交
67 68 69 70 71 72
/* Activation ops (and their gradients) that may be applied to SelectedRows.
 * Each of them produces zero for a zero input, which is what makes this
 * safe.
 */
static std::unordered_set<std::string> CanBeUsedBySelectedRows{
    "abs",    "abs_grad",     // |x|
    "square", "square_grad",  // x^2
    "sqrt",   "sqrt_grad",    // x^(1/2)
};

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
// Resolves the forward input X and output Out of an activation op.
//
// For op types in CanBeUsedBySelectedRows, the variables may hold a
// LoDTensor or a SelectedRows (whose value tensor is then used). Otherwise
// they are fetched as plain Tensors.
// PADDLE_ENFORCE fails if either variable is missing, or if the resolved
// Out tensor is null.
inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                                    const framework::Tensor** X,
                                    framework::Tensor** Out) {
  auto* x_var = context.InputVar("X");
  auto* out_var = context.OutputVar("Out");
  PADDLE_ENFORCE(x_var != nullptr,
                 "Cannot get input Variable X, variable name = %s",
                 context.op().Input("X"));
  PADDLE_ENFORCE(out_var != nullptr,
                 "Cannot get output Variable Out, variable name = %s",
                 context.op().Output("Out"));

  const bool selected_rows_ok =
      CanBeUsedBySelectedRows.count(context.op().Type()) > 0;
  *X = selected_rows_ok
           ? paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var)
           : context.Input<framework::Tensor>("X");
  *Out = selected_rows_ok
             ? paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
                   out_var)
             : context.Output<framework::Tensor>("Out");

  PADDLE_ENFORCE(*Out != nullptr,
                 "Cannot get output tensor Out, variable name = %s",
                 context.op().Output("Out"));
}

98
template <ActBwdOpFwdDeps kDepValue>
99 100 101 102 103 104
inline void ExtractActivationGradTensor(
    const framework::ExecutionContext& context, const framework::Tensor** X,
    const framework::Tensor** Out, const framework::Tensor** dOut,
    framework::Tensor** dX) {
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
105 106 107 108 109 110 111 112
  const framework::Variable* out_var = nullptr;

  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
    out_var = context.InputVar("Out");
    PADDLE_ENFORCE(out_var != nullptr,
                   "Cannot get input Variable Out, variable name = %s",
                   context.op().Input("Out"));
  }
113 114 115 116 117 118 119 120 121 122 123 124 125 126
  PADDLE_ENFORCE(out_grad_var != nullptr,
                 "Cannot get input Variable %s, variable name = %s",
                 framework::GradVarName("Out"),
                 context.op().Input(framework::GradVarName("Out")));
  PADDLE_ENFORCE(x_grad_var != nullptr,
                 "Cannot get output Variable %s, variable name = %s",
                 framework::GradVarName("X"),
                 context.op().Output(framework::GradVarName("X")));

  if (CanBeUsedBySelectedRows.count(context.op().Type())) {
    *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *out_grad_var);
    *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        x_grad_var);
127 128 129 130 131 132 133 134

    if (out_var) {
      *Out =
          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
    } else {
      *Out = *dOut;  // fake out
    }

135 136 137 138
  } else {
    *Out = context.Input<framework::Tensor>("Out");
    *dOut = context.Input<framework::Tensor>(framework::GradVarName("Out"));
    *dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
139 140 141 142 143 144

    if (out_var) {
      *Out = &(out_var->Get<framework::LoDTensor>());
    } else {
      *Out = *dOut;  // fake out
    }
145
  }
146

147 148 149 150 151
  PADDLE_ENFORCE(*dX != nullptr,
                 "Cannot get output tensor %s, variable name = %s",
                 framework::GradVarName("X"),
                 context.op().Output(framework::GradVarName("X")));

152
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
C
chengduo 已提交
153 154
    auto x_var = context.InputVar("X");
    PADDLE_ENFORCE(x_var != nullptr,
155
                   "Cannot get input tensor X, variable name = %s",
C
chengduo 已提交
156 157
                   context.op().Input("X"));
    if (CanBeUsedBySelectedRows.count(context.op().Type())) {
158
      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
C
chengduo 已提交
159
    } else {
160
      *X = context.Input<framework::Tensor>("X");
C
chengduo 已提交
161
    }
162 163 164 165 166
  } else {
    VLOG(10) << " Inplace activation of Op : " << context.op().Type();
    *X = *dX;
  }
}
C
chengduo 已提交
167

168 169 170 171 172
// Generic forward kernel. It computes Out = Functor(X) over the flattened
// tensors. Before the functor runs, each attribute slot returned by
// Functor::GetAttrs() is filled from the op's float attribute of that name.
template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* in_tensor = nullptr;
    framework::Tensor* out_tensor = nullptr;
    ExtractActivationTensor(context, &in_tensor, &out_tensor);
    out_tensor->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(detail::Ref(in_tensor));
    auto out = framework::EigenVector<T>::Flatten(detail::Ref(out_tensor));
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* place = dev_ctx.eigen_device();

    Functor functor;
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = context.Attr<float>(name_and_slot.first);
    }
    functor(*place, x, out);
  }
};

Q
QI JUN 已提交
194
template <typename DeviceContext, typename Functor>
195 196
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Q
qijun 已提交
197
 public:
198
  using T = typename Functor::ELEMENT_TYPE;
Q
qijun 已提交
199
  void Compute(const framework::ExecutionContext& context) const override {
200 201 202
    const framework::Tensor *X, *Out, *dOut;
    framework::Tensor* dX = nullptr;
    X = Out = dOut = nullptr;
203 204
    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
                                                    &dX);
Q
qijun 已提交
205
    dX->mutable_data<T>(context.GetPlace());
206 207 208 209
    auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
    auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
Q
QI JUN 已提交
210 211
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
Q
qijun 已提交
212
    Functor functor;
213 214 215 216
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
217
    functor(*place, x, out, dout, dx);
Q
qijun 已提交
218 219 220
  }
};

221 222 223 224 225 226 227
// Common base of all activation functors.
//   ELEMENT_TYPE - the scalar type the functor operates on.
//   GetAttrs()   - (attribute name, destination) pairs that the kernels fill
//                  from float op attributes. The base has none. Functors
//                  with attributes hide (not override) this method.
//   Inplace()    - whether Out may reuse X's memory. That is only safe when
//                  the gradient does not read X: sigmoid's gradient uses only
//                  Out, but abs's gradient uses X. The base returns false.
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;
  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return {}; }

  bool Inplace() const { return false; }
};

238
// sigmoid(x) = 1 / (1 + exp(-x))
Q
qijun 已提交
239
template <typename T>
240
struct SigmoidFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
241 242 243
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
Q
qijun 已提交
244 245 246
  }
};

247
template <typename T>
248
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
249 250 251 252
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
Q
qijun 已提交
253
  }
254 255

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
Q
qijun 已提交
256 257
};

258 259 260 261
// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
F
fengjiayi 已提交
262
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
263 264 265 266 267 268 269 270 271 272
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x -
//           max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = -log( exp(max(-x, 0)) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
273 274
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
275
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
F
fengjiayi 已提交
276
    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
277 278 279 280 281 282 283 284
  }
};

// logsigmoid'(x) = exp(-x) / (1 + exp(-x)).
// Multiplying the numerator and denominator by exp(-m), with m = max(-x, 0),
// keeps every exponent non-positive:
//   logsigmoid'(x) = exp(-x - m) / (exp(-m) + exp(-x - m))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto m = (-x).cwiseMax(static_cast<T>(0));
    auto shifted = (-x - m).exp();
    dx.device(d) = dout * (shifted / ((-m).exp() + shifted));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

Q
qijun 已提交
296
// exp(x) = e^x
297 298
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
299 300 301
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.exp();
Q
qijun 已提交
302 303 304
  }
};

305 306
// d/dx exp(x) = exp(x) = out, so only the forward output is needed.
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = out * dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

Q
qijun 已提交
316
// relu(x) = max(x, 0)
Q
qijun 已提交
317
template <typename T>
318
struct ReluFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
319 320 321
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
Q
qijun 已提交
322 323
  }
};
Q
qijun 已提交
324

Q
qijun 已提交
325
template <typename T>
326
struct ReluGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
327 328 329
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
D
dzhwinter 已提交
330
    dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
Q
qijun 已提交
331
  }
332 333

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
Q
qijun 已提交
334
};
Q
qijun 已提交
335

C
Clementine 已提交
336 337 338 339 340
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), element-wise.
template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
// The device/execution context is not available here, so the CPU-only MKL
// path is selected at compile time. Builds with PADDLE_WITH_CUDA (NVCC)
// always take the Eigen path below.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
    auto x_data = x.data();
    auto out_data = out.data();
    // The CBlas wrappers below take an int element count.
    const int n = static_cast<int>(std::min(x.size(), out.size()));

    // out = x / sqrt(2). AXPY computes out += a * x, so zero out first.
    std::memset(out_data, 0, n * sizeof(T));
    math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
    // out = erf(out)
    math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
    // out = 1 + erf(x / sqrt(2))
    for (int i = 0; i < n; i++) {
      out_data[i] += static_cast<T>(1);
    }
    // out = x * (1 + erf(x / sqrt(2)))
    math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
    // out = 0.5 * x * (1 + erf(x / sqrt(2)))
    for (int i = 0; i < n; i++) {
      out_data[i] *= static_cast<T>(0.5);
    }
#else
    auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
    out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
  }
};

// gelu'(x) = 0.5 * (1 + erf(x / sqrt(2)))          (standard normal CDF)
//          + x * exp(-x^2 / 2) / sqrt(2 * pi)      (x * standard normal PDF)
// where 0.5 * M_2_SQRTPI * M_SQRT1_2 == 1 / sqrt(2 * pi).
template <typename T>
struct GeluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto cdf_term =
        static_cast<T>(0.5) *
        (static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
    auto pdf_term = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
                    (-static_cast<T>(0.5) * x.square()).exp();
    dx.device(d) = dout * (cdf_term + pdf_term);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

382
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
383 384
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
385 386 387
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.tanh();
Q
qijun 已提交
388 389 390 391
  }
};

template <typename T>
392
struct TanhGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
393 394 395 396
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
Q
qijun 已提交
397
  }
398 399

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
Q
qijun 已提交
400 401
};

K
Kavya Srinet 已提交
402 403 404 405
// tanhshrink(x) = x - tanh(x), element-wise.
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tanh_x = x.tanh();
    out.device(d) = x - tanh_x;
  }
};

// tanhshrink'(x) = 1 - (1 - tanh(x)^2) = tanh(x)^2. This needs the forward
// input X.
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto t = x.tanh();
    dx.device(d) = dout * (t * t);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

423 424 425 426 427 428 429 430 431
// hard_shrink(x) = x, if x > threshold or x < -threshold
//                  0, otherwise
// `threshold` is filled from the op attribute of the same name.
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    out.device(d) = x * (temp1 + temp2);
  }
};

// hard_shrink'(x) = 1 where x > threshold or x < -threshold, 0 elsewhere.
template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto below = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto above = (x > static_cast<T>(threshold)).template cast<T>().eval();
    // below + above is already of type T; no further cast is needed.
    dx.device(d) = dout * (below + above);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

K
Kexin Zhao 已提交
459
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
460 461 462 463 464 465 466 467
// otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

F
fengjiayi 已提交
468 469
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
Y
Yu Yang 已提交
470 471 472
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
F
fengjiayi 已提交
473
    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
474 475 476 477 478 479 480 481 482
  }
};

// softshrink'(x) = 1 where x > lambda or x < -lambda, 0 elsewhere.
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lam = static_cast<T>(lambda);
    auto above = (x > lam).template cast<T>().eval();
    auto below = (x < -lam).template cast<T>().eval();
    // above + below is already of type T; no further cast is needed.
    dx.device(d) = dout * (above + below);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

Q
qijun 已提交
495
// sqrt(x) = x^(1/2)
496 497
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
498 499 500
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.sqrt();
Q
qijun 已提交
501 502 503 504
  }
};

template <typename T>
505
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
506 507 508
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
C
chengduo 已提交
509
    dx.device(d) = static_cast<T>(0.5) * dout / out;
Q
qijun 已提交
510
  }
511 512

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
Q
qijun 已提交
513 514
};

Z
zhoukunsheng 已提交
515 516 517 518 519 520 521 522 523 524 525 526 527 528
// rsqrt(x) = x^(-1/2) = 1 / sqrt(x), element-wise.
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto inv_root = x.rsqrt();
    out.device(d) = inv_root;
  }
};

// rsqrt'(x) = -0.5 * x^(-3/2) = -0.5 * out^3, so only the forward output is
// needed.
template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
  }

  // Required by ActivationGradKernel, which instantiates
  // ExtractActivationGradTensor<Functor::FwdDeps()>. The base class does not
  // provide it.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

D
dzhwinter 已提交
533 534 535
// ceil(x): the smallest integral value not less than x, element-wise.
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto rounded_up = x.ceil();
    out.device(d) = rounded_up;
  }
};

// Gradient functor that fills dX with zeros. It is intended for activations
// whose derivative is zero almost everywhere (presumably ceil/floor/round;
// the registrations are not visible in this header).
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // Write the constant directly rather than computing `0 / out`. With
    // kNoDeps, `out` is only the fake Out aliased to dOut, so a division
    // would produce NaN (0/0) wherever the incoming gradient is zero.
    dx.device(d) = dx.constant(static_cast<T>(0));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};

// floor(x): the largest integral value not greater than x, element-wise.
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto rounded_down = x.floor();
    out.device(d) = rounded_down;
  }
};

C
add cos  
chengduoZH 已提交
562 563 564 565 566
template <typename T>
struct Sine {
  HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};

567 568 569 570 571 572 573
template <>
struct Sine<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    return platform::float16(sin(static_cast<float>(val)));
  }
};

C
add cos  
chengduoZH 已提交
574 575 576 577 578
template <typename T>
struct Cosine {
  HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};

579 580 581 582 583 584 585
template <>
struct Cosine<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    return platform::float16(cos(static_cast<float>(val)));
  }
};

C
add cos  
chengduoZH 已提交
586 587 588 589 590 591 592 593
// cos'(x) = -sin(x). This needs the forward input X.
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto sin_x = x.unaryExpr(Sine<T>());
    dx.device(d) = -dout * sin_x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// cos(x), element-wise.
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto cos_x = x.unaryExpr(Cosine<T>());
    out.device(d) = cos_x;
  }
};

// sin'(x) = cos(x). This needs the forward input X.
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto cos_x = x.unaryExpr(Cosine<T>());
    dx.device(d) = dout * cos_x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// sin(x), element-wise.
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto sin_x = x.unaryExpr(Sine<T>());
    out.device(d) = sin_x;
  }
};

628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
// Scalar arc cosine, usable with Eigen's unaryExpr (AcosFunctor).
template <typename T>
struct Acos {
  HOSTDEVICE T operator()(const T& val) const { return acos(val); }
};

// The float16 specialization computes in float and converts the result back.
template <>
struct Acos<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    const float as_float = static_cast<float>(val);
    return platform::float16(acos(as_float));
  }
};

// acos(x), element-wise.
template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto acos_x = x.unaryExpr(Acos<T>());
    out.device(d) = acos_x;
  }
};

// acos'(x) = -1 / sqrt(1 - x^2). This needs the forward input X.
template <typename T>
struct AcosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto root = (static_cast<T>(1) - x.square()).sqrt();
    dx.device(d) = -dout * static_cast<T>(1) / root;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scalar arc sine, usable with Eigen's unaryExpr (AsinFunctor).
template <typename T>
struct Asin {
  HOSTDEVICE T operator()(const T& val) const { return asin(val); }
};

// The float16 specialization computes in float and converts the result back.
template <>
struct Asin<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    const float as_float = static_cast<float>(val);
    return platform::float16(asin(as_float));
  }
};

// asin(x), element-wise.
template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto asin_x = x.unaryExpr(Asin<T>());
    out.device(d) = asin_x;
  }
};

// asin'(x) = 1 / sqrt(1 - x^2). This needs the forward input X.
template <typename T>
struct AsinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto root = (static_cast<T>(1) - x.square()).sqrt();
    dx.device(d) = dout * static_cast<T>(1) / root;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scalar arc tangent, usable with Eigen's unaryExpr (AtanFunctor).
template <typename T>
struct Atan {
  HOSTDEVICE T operator()(const T& val) const { return atan(val); }
};

// The float16 specialization computes in float and converts the result back.
template <>
struct Atan<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    const float as_float = static_cast<float>(val);
    return platform::float16(atan(as_float));
  }
};

// atan(x), element-wise.
template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto atan_x = x.unaryExpr(Atan<T>());
    out.device(d) = atan_x;
  }
};

// atan'(x) = 1 / (1 + x^2). This needs the forward input X.
template <typename T>
struct AtanGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto denom = static_cast<T>(1) + x.square();
    dx.device(d) = dout * static_cast<T>(1) / denom;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

D
dzhwinter 已提交
729 730 731
// round(x): x rounded to the nearest integral value, element-wise.
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto nearest = x.round();
    out.device(d) = nearest;
  }
};

Q
qijun 已提交
738
// abs(x) = |x|
739 740
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
741 742 743
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.abs();
Q
qijun 已提交
744 745 746
  }
};

747 748
// abs'(x) = sign(x).
template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto sign_x = x.sign();
    dx.device(d) = dout * sign_x;
  }

  // Only X is read above. Both X and Out are declared because of the MKLDNN
  // abs FIXME noted on kDepXOut.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; }
};

Q
qijun 已提交
758 759
// reciprocal(x) = 1 / x
template <typename T>
760
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
761 762 763
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / x;
Q
qijun 已提交
764 765 766
  }
};

767
template <typename T>
768
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
769 770 771 772
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
Q
qijun 已提交
773
  }
774 775

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
Q
qijun 已提交
776 777 778
};

// log(x) = natural logarithm of x
779 780
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
781 782 783
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
Q
qijun 已提交
784 785 786
  }
};

787
template <typename T>
788
struct LogGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
789 790 791 792
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
Q
qijun 已提交
793
  }
794 795

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
Q
qijun 已提交
796 797 798
};

// square(x) = x^2
799 800
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
801 802 803
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.square();
Q
qijun 已提交
804
  }
805
};
Q
qijun 已提交
806

807
template <typename T>
808
struct SquareGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
809 810 811 812
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
813
  }
814 815

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
816 817
};

818 819 820 821 822 823 824 825 826 827
template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
  // not polymorphism for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
828

F
fengjiayi 已提交
829 830 831
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
Y
Yu Yang 已提交
832
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
833 834 835
  }
};

836 837 838 839 840 841 842
template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
F
fengjiayi 已提交
843 844 845 846
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
Y
Yu Yang 已提交
847 848
                   ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                       .template cast<T>();
849
  }
850 851

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
852 853
};

854 855 856 857 858 859 860 861 862
// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

F
fengjiayi 已提交
863 864 865
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
Y
Yu Yang 已提交
866
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
867 868 869 870 871 872 873 874 875
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Gradient is 1 where the forward output lies strictly inside
  // (0, threshold) and 0 elsewhere; only Out is needed to decide that.
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto active =
        (out > static_cast<T>(0)) * (out < static_cast<T>(threshold));
    dx.device(d) = dout * active.template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

K
kexinzhao 已提交
888 889 890 891 892 893 894
// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may explode to inf,
// Using trick below for numerical stability
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
895 896
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
K
kexinzhao 已提交
897
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
F
fengjiayi 已提交
898
    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
K
kexinzhao 已提交
899 900 901 902 903 904 905 906 907
  }
};

// d(softplus(x))/dx = exp(x) / (1 + exp(x))
// For numerical stability:
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  // operator() is const, matching the other grad functors in this file.
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    // Same overflow-free rewrite as the forward pass: every exponent is <= 0.
    dx.device(d) =
        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

919 920
// softsign(x) = x / (1 + |x|)
template <typename T>
921
struct SoftsignFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
922 923 924
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
925 926 927 928 929 930
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
931
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
932 933 934
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
935
    dx.device(d) =
F
fengjiayi 已提交
936
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
937
  }
938 939

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
940 941
};

942 943 944 945 946 947
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
948

F
fengjiayi 已提交
949 950
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
Y
Yu Yang 已提交
951 952
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
F
fengjiayi 已提交
953
    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
954 955 956
  }
};

957 958 959 960 961 962
template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
F
fengjiayi 已提交
963 964 965
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
Y
Yu Yang 已提交
966
    auto tmp = static_cast<T>(threshold);
D
dzhwinter 已提交
967
    auto temp = ((out > -tmp) * (out < tmp)).template cast<T>().eval();
F
fengjiayi 已提交
968
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
969
  }
970 971

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
972 973
};

K
Kavya Srinet 已提交
974 975 976 977 978 979
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
980

F
fengjiayi 已提交
981 982 983
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
984 985 986
  }
};

K
Kavya Srinet 已提交
987 988 989 990 991 992
template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
F
fengjiayi 已提交
993 994 995
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
Y
Yu Yang 已提交
996 997
    auto temp1 = static_cast<T>(alpha) *
                 (x < static_cast<T>(0)).template cast<T>().eval();
K
Kavya Srinet 已提交
998
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
F
fengjiayi 已提交
999
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
1000
  }
1001 1002

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
1003 1004
};

1005 1006 1007 1008 1009 1010
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
1011

F
fengjiayi 已提交
1012 1013 1014 1015 1016
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0)) +
                    (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
                        .cwiseMin(static_cast<T>(0));
1017 1018 1019
  }
};

1020 1021 1022 1023 1024 1025
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
F
fengjiayi 已提交
1026 1027 1028 1029
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
1030
                   dout * static_cast<T>(alpha) * x.exp() *
Y
Yu Yang 已提交
1031
                       (x < static_cast<T>(0)).template cast<T>();
1032
  }
1033 1034

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
1035 1036
};

Q
QI JUN 已提交
1037
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
1038 1039 1040 1041 1042 1043
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
F
fengjiayi 已提交
1044 1045 1046
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.pow(static_cast<T>(factor));
1047 1048 1049
  }
};

1050 1051 1052 1053 1054 1055
template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
F
fengjiayi 已提交
1056 1057 1058 1059
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(factor) *
C
chengduo 已提交
1060
                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
1061
  }
1062 1063

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
1064 1065
};

1066 1067 1068 1069 1070 1071 1072
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }
1073

F
fengjiayi 已提交
1074 1075 1076
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
Y
Yu Yang 已提交
1077
        static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
1078 1079 1080
  }
};

1081 1082 1083 1084 1085 1086 1087
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }
1088

F
fengjiayi 已提交
1089 1090 1091
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
Y
Yu Yang 已提交
1092 1093 1094
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
F
fengjiayi 已提交
1095
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
Q
qijun 已提交
1096
  }
1097 1098

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
Q
qijun 已提交
1099 1100
};

1101 1102 1103 1104 1105 1106 1107
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

F
fengjiayi 已提交
1108 1109
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
Y
Yu Yang 已提交
1110
    auto th = static_cast<T>(threshold);
F
fengjiayi 已提交
1111
    out.device(d) = (x > th).template cast<T>() * x;
1112 1113 1114 1115 1116 1117 1118 1119 1120 1121
  }
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Gradient passes through only where x > threshold.
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T th = static_cast<T>(threshold);
    auto above = (x > th).template cast<T>();
    dx.device(d) = dout * above;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

1132 1133 1134 1135 1136 1137 1138 1139
template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

F
fengjiayi 已提交
1140 1141
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
1142
    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
F
fengjiayi 已提交
1143 1144
    out.device(d) =
        temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
1145 1146 1147 1148 1149 1150 1151 1152 1153 1154
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // Gradient is `slope` where the forward output is strictly inside (0, 1)
  // and 0 where it was clipped; decided from Out alone.
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto unsaturated =
        (out > static_cast<T>(0)) * (out < static_cast<T>(1));
    dx.device(d) =
        dout * unsaturated.template cast<T>() * static_cast<T>(slope);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

A
Abhinav Arora 已提交
1167 1168 1169 1170 1171 1172 1173
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

F
fengjiayi 已提交
1174 1175 1176
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
A
Abhinav Arora 已提交
1177 1178 1179 1180 1181 1182 1183 1184 1185 1186
  }
};

template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // With s = sigmoid(beta * x) and y = x * s:
  //   d(swish)/dx = beta * y + s * (1 - beta * y).
  // The forward output is recomputed from X here, so the Out argument
  // (`fake_out`) is not read.
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    auto sig = one / (one + (static_cast<T>(-beta) * x).exp());
    auto swish = x * sig;
    auto beta_swish = static_cast<T>(beta) * swish;
    auto tail = sig * (one - beta_swish);
    dx.device(d) = dout * (beta_swish + tail);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

Q
qijun 已提交
1200 1201
}  // namespace operators
}  // namespace paddle
1202

Z
zhoukunsheng 已提交
1203

1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240
// X-macro listing the activation ops and their functor pairs. Each entry has
// the form
//   __macro(op_name, OpCamelName, ForwardFunctor, GradFunctor);
// and the caller supplies `__macro` to generate per-op code (presumably
// op/kernel registration — the expansion sites are not visible here).
// ceil, floor and round pair their forward functor with ZeroGradFunctor.
// NOTE(review): this #define follows the closing of paddle::operators and
// uses unqualified functor names, so expansion sites presumably sit inside
// that namespace — confirm.
#define FOR_EACH_ACTIVATION_OP(__macro)                                       \
  __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
  __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);  \
  __macro(exp, Exp, ExpFunctor, ExpGradFunctor);                              \
  __macro(relu, Relu, ReluFunctor, ReluGradFunctor);                          \
  __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor);                          \
  __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor);                          \
  __macro(atan, Atan, AtanFunctor, AtanGradFunctor);                          \
  __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor);  \
  __macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);                          \
  __macro(abs, Abs, AbsFunctor, AbsGradFunctor);                              \
  __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor);                          \
  __macro(floor, Floor, FloorFunctor, ZeroGradFunctor);                       \
  __macro(cos, Cos, CosFunctor, CosGradFunctor);                              \
  __macro(acos, Acos, AcosFunctor, AcosGradFunctor);                          \
  __macro(sin, Sin, SinFunctor, SinGradFunctor);                              \
  __macro(asin, Asin, AsinFunctor, AsinGradFunctor);                          \
  __macro(round, Round, RoundFunctor, ZeroGradFunctor);                       \
  __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);  \
  __macro(log, Log, LogFunctor, LogGradFunctor);                              \
  __macro(square, Square, SquareFunctor, SquareGradFunctor);                  \
  __macro(brelu, BRelu, BReluFunctor, BReluGradFunctor);                      \
  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);         \
  __macro(pow, Pow, PowFunctor, PowGradFunctor);                              \
  __macro(stanh, STanh, STanhFunctor, STanhGradFunctor);                      \
  __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor);          \
  __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);          \
  __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor);                      \
  __macro(leaky_relu, LeakyRelu, LeakyReluFunctor, LeakyReluGradFunctor);     \
  __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
  __macro(elu, ELU, ELUFunctor, ELUGradFunctor);                              \
  __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
  __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,                      \
          HardSigmoidGradFunctor);                                            \
  __macro(swish, Swish, SwishFunctor, SwishGradFunctor);                      \
  __macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor,          \
          ThresholdedReluGradFunctor);