activation_op.h 30.9 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
L
Luo Tao 已提交
2 3 4 5 6 7 8 9 10
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
11 12

#pragma once
D
dzhwinter 已提交
13 14 15
#include <glog/logging.h>
#include <string>
#include <unordered_set>
16 17
#include <utility>
#include <vector>
18

Y
Yi Wang 已提交
19 20 21
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
22
#include "paddle/fluid/platform/float16.h"
Q
qijun 已提交
23

24 25 26 27
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

Q
qijun 已提交
28 29 30
namespace paddle {
namespace operators {

D
dzhwinter 已提交
31 32 33 34 35 36 37 38 39 40
/* Names of activation ops whose gradient functor can run in place, i.e. the
   gradient is computed from Out alone so Out may reuse X's memory.
   This global is also relied upon from the python layer side; see
   layer_helper.py for details.
 */
static std::unordered_set<std::string> InplaceOpSet = {
    "sigmoid", "exp",        "relu",  "tanh",      "sqrt",         "ceil",
    "floor",   "reciprocal", "relu6", "soft_relu", "hard_sigmoid",
};

// Returns true when `op` is listed in InplaceOpSet. Takes the name by const
// reference so callers passing literals or strings do not pay for a copy.
static bool IsInplace(const std::string& op) {
  return InplaceOpSet.count(op) > 0;
}

Q
QI JUN 已提交
41
template <typename DeviceContext, typename Functor>
42 43
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Q
qijun 已提交
44
 public:
45 46
  using T = typename Functor::ELEMENT_TYPE;

Q
qijun 已提交
47
  void Compute(const framework::ExecutionContext& context) const override {
Y
Update  
Yang Yu 已提交
48 49 50 51 52 53 54 55 56 57
    auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
                          "Cannot get input tensor X, variable name = %s",
                          context.op().Input("X"));

    auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
                            "Cannot get output tensor Out, variable name = %s",
                            context.op().Output("Out"));
    Out.mutable_data<T>(context.GetPlace());
    auto x = framework::EigenVector<T>::Flatten(X);
    auto out = framework::EigenVector<T>::Flatten(Out);
Q
QI JUN 已提交
58 59
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
Q
qijun 已提交
60
    Functor functor;
61 62 63 64 65

    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
F
fengjiayi 已提交
66
    functor(*place, x, out);
Q
qijun 已提交
67 68 69
  }
};

Q
QI JUN 已提交
70
template <typename DeviceContext, typename Functor>
71 72
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Q
qijun 已提交
73
 public:
74
  using T = typename Functor::ELEMENT_TYPE;
Q
qijun 已提交
75
  void Compute(const framework::ExecutionContext& context) const override {
F
fengjiayi 已提交
76 77 78
    auto* Out = context.Input<framework::Tensor>("Out");
    auto* dOut =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
Q
qijun 已提交
79 80 81
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    dX->mutable_data<T>(context.GetPlace());

F
fengjiayi 已提交
82 83
    auto dout = framework::EigenVector<T>::Flatten(*dOut);
    auto out = framework::EigenVector<T>::Flatten(*Out);
Q
qijun 已提交
84
    auto dx = framework::EigenVector<T>::Flatten(*dX);
Q
QI JUN 已提交
85 86
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
Q
qijun 已提交
87
    Functor functor;
88 89 90 91
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
D
dzhwinter 已提交
92 93 94 95 96 97 98 99 100 101
    bool inplace = functor.Inplace();
    if (!inplace) {
      auto* X = context.Input<framework::Tensor>("X");
      auto x = framework::EigenVector<T>::Flatten(*X);
      functor(*place, x, out, dout, dx);
    } else {
      VLOG(10) << " Inplace activation ";
      auto x = framework::EigenVector<T>::Flatten(*dX);
      functor(*place, x, out, dout, dx);
    }
Q
qijun 已提交
102 103 104
  }
};

105 106 107 108 109 110 111
/* Common base of every activation functor.
 *
 * ELEMENT_TYPE exposes the scalar type to the kernels. Functors that take
 * float attributes hide (not override) GetAttrs() and return
 * (attribute name, pointer to member) pairs; the kernels fill those members
 * from the op's attributes before evaluation.
 */
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  // No attributes by default.
  AttrPair GetAttrs() { return AttrPair{}; }

  /* NOTE(*): Reports whether the gradient can be computed without X, so that
     Out is allowed to reuse X's memory. For example, sigmoid's gradient only
     needs Out and can be in-place, whereas abs's gradient reads X and cannot.
     Gradient functors that qualify hide this with their own Inplace().
   */
  bool Inplace() const { return false; }
};

122
// sigmoid(x) = 1 / (1 + exp(-x))
Q
qijun 已提交
123
template <typename T>
124
struct SigmoidFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
125 126 127
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
Q
qijun 已提交
128 129 130
  }
};

131
template <typename T>
132
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
D
dzhwinter 已提交
133
  bool Inplace() const { return IsInplace("sigmoid"); }
F
fengjiayi 已提交
134 135 136 137
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
Q
qijun 已提交
138 139 140
  }
};

141 142 143 144
// Originally: logsigmoid(x) = -log(1 + exp(-x))
// Evaluated directly, exp(-x) overflows for very negative x. For numerical
// stability use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// With m = max(-x, 0):
//   out = -log(exp(0) + exp(-x))                    [since exp(0) = 1]
//       = -log(exp(m) * (exp(-m) + exp(-x - m)))
//       = -(m + log(exp(-m) + exp(-x - m)))
// Both exponents -m and -x - m are <= 0, so neither exp() can overflow.
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// Multiplying numerator and denominator by exp(-m), with m = max(-x, 0) as
// above, gives the stable form:
//   f' = exp(-x - m) / (exp(-m) + exp(-x - m))
// X is read here, so this gradient does not run in place.
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
};

Q
qijun 已提交
177
// exp(x) = e^x
178 179
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
180 181 182
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.exp();
Q
qijun 已提交
183 184 185
  }
};

186 187
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
D
dzhwinter 已提交
188
  bool Inplace() const { return IsInplace("exp"); }
F
fengjiayi 已提交
189 190 191 192
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
Q
qijun 已提交
193 194 195
  }
};

Q
qijun 已提交
196
// relu(x) = max(x, 0)
Q
qijun 已提交
197
template <typename T>
198
struct ReluFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
199 200 201
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
Q
qijun 已提交
202 203
  }
};
Q
qijun 已提交
204

Q
qijun 已提交
205
template <typename T>
206
struct ReluGradFunctor : public BaseActivationFunctor<T> {
D
dzhwinter 已提交
207
  bool Inplace() const { return IsInplace("relu"); }
F
fengjiayi 已提交
208 209 210
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
D
dzhwinter 已提交
211
    dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
Q
qijun 已提交
212 213
  }
};
Q
qijun 已提交
214

215
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
216 217
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
218 219 220
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.tanh();
Q
qijun 已提交
221 222 223 224
  }
};

template <typename T>
225
struct TanhGradFunctor : public BaseActivationFunctor<T> {
D
dzhwinter 已提交
226
  bool Inplace() const { return IsInplace("tanh"); }
F
fengjiayi 已提交
227 228 229 230
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
Q
qijun 已提交
231 232 233
  }
};

K
Kavya Srinet 已提交
234 235 236 237
// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tanh_x = x.tanh();
    out.device(d) = x - tanh_x;
  }
};

// tanhshrink'(x) = 1 - (1 - tanh(x)^2) = tanh(x)^2
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto tanh_x = x.tanh();
    dx.device(d) = dout * (tanh_x * tanh_x);
  }
};

253 254 255 256 257 258 259 260 261
// hardshrink(x) = x, if x > threshold or x < -threshold; 0 otherwise
// NOTE(review): assumes threshold >= 0; for a negative threshold the two masks
// below overlap and their sum can reach 2.
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // 0/1 masks for the lower and upper tails that are passed through.
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    out.device(d) = x * (temp1 + temp2);
  }
};

// hardshrink'(x) = 1, if x > threshold or x < -threshold; 0 otherwise
template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

K
Kexin Zhao 已提交
287
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda;
// 0 otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T lam = static_cast<T>(lambda);
    auto above = (x > lam).template cast<T>().eval();
    auto below = (x < -lam).template cast<T>().eval();
    out.device(d) = above * (x - lam) + below * (x + lam);
  }
};

// softshrink'(x) = 1 where x > lambda or x < -lambda, 0 otherwise
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T lam = static_cast<T>(lambda);
    auto above = (x > lam).template cast<T>().eval();
    auto below = (x < -lam).template cast<T>().eval();
    dx.device(d) = dout * (above + below).template cast<T>();
  }
};

Q
qijun 已提交
321
// sqrt(x) = x^(1/2)
322 323
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
324 325 326
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.sqrt();
Q
qijun 已提交
327 328 329 330
  }
};

template <typename T>
331
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
D
dzhwinter 已提交
332
  bool Inplace() const { return IsInplace("sqrt"); }
F
fengjiayi 已提交
333 334 335
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
336 337
    const Out out_conj = Eigen::numext::conj(out);
    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
Q
qijun 已提交
338 339 340
  }
};

D
dzhwinter 已提交
341 342 343
// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.ceil();
  }
};

// Gradient for piecewise-constant activations such as ceil: the derivative is
// 0 wherever it is defined, so dx is filled with zeros. Neither X nor dOut is
// read; in-place eligibility is reported through IsInplace("ceil").
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  bool Inplace() const { return IsInplace("ceil"); }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // The previous `0 / out` produced NaN (0 / 0) wherever out == 0, e.g.
    // ceil(-0.5). A constant-zero expression has no such hazard; only dx's
    // dimensions are used, its contents are not read.
    dx.device(d) = dx.constant(static_cast<T>(0));
  }
};

// floor(x) = the largest integral value not greater than x
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto floored = x.floor();
    out.device(d) = floored;
  }
};

C
add cos  
chengduoZH 已提交
369 370 371 372 373
template <typename T>
struct Sine {
  HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};

374 375 376 377 378 379 380
template <>
struct Sine<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    return platform::float16(sin(static_cast<float>(val)));
  }
};

C
add cos  
chengduoZH 已提交
381 382 383 384 385
template <typename T>
struct Cosine {
  HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};

386 387 388 389 390 391 392
template <>
struct Cosine<platform::float16> {
  HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
    return platform::float16(cos(static_cast<float>(val)));
  }
};

C
add cos  
chengduoZH 已提交
393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = -dout * x.unaryExpr(Sine<T>());
  }
};

// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Cosine<T>());
  }
};

// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Cosine<T>());
  }
};

// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Sine<T>());
  }
};

D
dzhwinter 已提交
431 432 433
// round(x) = x rounded to an integral value (Eigen's round())
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto rounded = x.round();
    out.device(d) = rounded;
  }
};

Q
qijun 已提交
440
// abs(x) = |x|
441 442
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
443 444 445
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.abs();
Q
qijun 已提交
446 447 448
  }
};

449 450
template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
451 452 453 454
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.sign();
455 456 457
  }
};

Q
qijun 已提交
458 459
// reciprocal(x) = 1 / x
template <typename T>
460
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
461 462 463
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / x;
Q
qijun 已提交
464 465 466
  }
};

467
template <typename T>
468
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
D
dzhwinter 已提交
469
  bool Inplace() const { return IsInplace("reciprocal"); }
F
fengjiayi 已提交
470 471 472 473
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
Q
qijun 已提交
474 475 476 477
  }
};

// log(x) = natural logarithm of x
478 479
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
480 481 482
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
Q
qijun 已提交
483 484 485
  }
};

486
template <typename T>
487
struct LogGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
488 489 490 491
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
Q
qijun 已提交
492 493 494 495
  }
};

// square(x) = x^2
496 497
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
498 499 500
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.square();
Q
qijun 已提交
501
  }
502
};
Q
qijun 已提交
503

504
template <typename T>
505
struct SquareGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
506 507 508 509
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
510 511 512
  }
};

513 514 515 516 517 518 519 520 521 522
// brelu(x) = min(max(x, t_min), t_max)
template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: Hides (does not override) `BaseActivationFunctor<T>::GetAttrs`;
  // the kernels call it on the concrete functor type, so there is no
  // virtual dispatch cost.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    out.device(d) = x.cwiseMax(lo).cwiseMin(hi);
  }
};

// brelu'(x) = 1 where t_min < x < t_max, 0 elsewhere
template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    auto in_range = ((x > lo) * (x < hi)).template cast<T>();
    dx.device(d) = dout * in_range;
  }
};

547 548 549 550 551 552 553 554 555
// relu6(x) = min(max(0, x), threshold); the op is named for the usual
// threshold of 6, but the bound comes from the "threshold" attribute.
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T zero = static_cast<T>(0);
    const T cap = static_cast<T>(threshold);
    out.device(d) = x.cwiseMax(zero).cwiseMin(cap);
  }
};

// relu6'(x) = 1 where 0 < out < threshold, 0 elsewhere; only Out is read.
template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  bool Inplace() const { return IsInplace("relu6"); }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T zero = static_cast<T>(0);
    const T cap = static_cast<T>(threshold);
    auto active = ((out > zero) * (out < cap)).template cast<T>();
    dx.device(d) = dout * active;
  }
};

K
kexinzhao 已提交
580 581 582 583 584 585 586
// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may overflow to inf, so the
// log-sum-exp trick is used for numerical stability:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  // const-qualified like every other functor in this header; the functor
  // holds no state that evaluation needs to modify.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
  }
};

// d(softplus(x))/dx = exp(x) / (1 + exp(x))
// For numerical stability:
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    dx.device(d) =
        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }
};

609 610
// softsign(x) = x / (1 + |x|)
template <typename T>
611
struct SoftsignFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
612 613 614
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
615 616 617 618 619 620
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
621
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
F
fengjiayi 已提交
622 623 624
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
625
    dx.device(d) =
F
fengjiayi 已提交
626
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
627 628 629
  }
};

630 631 632 633 634 635
// soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold)))
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T bound = static_cast<T>(threshold);
    auto clipped = x.cwiseMax(-bound).cwiseMin(bound);
    const T one = static_cast<T>(1);
    out.device(d) = (one + clipped.exp()).log();
  }
};

// Gradient: 1 - exp(-out) (= sigmoid of the clipped input), masked by
// -threshold < out < threshold. The mask is evaluated on Out rather than X,
// so X is not needed.
// NOTE(review): out = log(1 + exp(.)) is always > 0, so `out > -threshold`
// always holds for threshold >= 0; inputs below -threshold therefore still get
// the small gradient 1 - exp(-out). Confirm this approximation is intended.
template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  bool Inplace() const { return IsInplace("soft_relu"); }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T bound = static_cast<T>(threshold);
    auto in_range =
        ((out > -bound) * (out < bound)).template cast<T>().eval();
    const T one = static_cast<T>(1);
    dx.device(d) = dout * (one - (-out).exp()) * in_range;
  }
};

K
Kavya Srinet 已提交
661 662 663 664 665 666
// leaky_relu(x) = max(x, alpha * x). This equals x for x >= 0 and alpha * x
// for x < 0 only when alpha <= 1.
// NOTE(review): assumes alpha <= 1; the gradient below always uses the
// piecewise form, so the two would disagree for alpha > 1.
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T slope = static_cast<T>(alpha);
    out.device(d) = x.cwiseMax(slope * x);
  }
};

// leaky_relu'(x) = alpha for x < 0, 1 for x >= 0
template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T slope = static_cast<T>(alpha);
    const T zero = static_cast<T>(0);
    auto neg_part = slope * (x < zero).template cast<T>().eval();
    auto pos_part = (x >= zero).template cast<T>().eval();
    dx.device(d) = dout * (neg_part + pos_part).template cast<T>();
  }
};

690 691 692 693 694 695
// elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T zero = static_cast<T>(0);
    const T one = static_cast<T>(1);
    const T a = static_cast<T>(alpha);
    out.device(d) = x.cwiseMax(zero) + (a * (x.exp() - one)).cwiseMin(zero);
  }
};

// elu'(x) = 1 for x > 0; alpha * exp(x) = out + alpha for x < 0; the
// gradient is 0 at exactly x == 0.
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T zero = static_cast<T>(0);
    const T a = static_cast<T>(alpha);
    dx.device(d) = dout * (x > zero).template cast<T>() +
                   dout * (out + a) * (x < zero).template cast<T>();
  }
};

Q
QI JUN 已提交
720
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
721 722 723 724 725 726
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
F
fengjiayi 已提交
727 728 729
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.pow(static_cast<T>(factor));
730 731 732
  }
};

733 734 735 736 737 738
template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
F
fengjiayi 已提交
739 740 741 742
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(factor) *
743
                   x.pow(static_cast<T>(factor - static_cast<T>(1)));
744 745 746
  }
};

747 748 749 750 751 752 753
// stanh(x) = scale_b * tanh(scale_a * x)
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T a = static_cast<T>(scale_a);
    const T b = static_cast<T>(scale_b);
    out.device(d) = b * (a * x).tanh();
  }
};

// stanh'(x) = scale_a * scale_b * (1 - tanh(scale_a * x)^2)
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T a = static_cast<T>(scale_a);
    const T b = static_cast<T>(scale_b);
    auto th = (a * x).tanh();
    auto th_sq = th * th;
    dx.device(d) = dout * a * b * (static_cast<T>(1) - th_sq);
  }
};

780 781 782 783 784 785 786
// Thresholded ReLU: passes x through where x > threshold, emits 0 elsewhere.
// `threshold` is read from the "threshold" op attribute.
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // 1 where the input exceeds the threshold, 0 otherwise.
    auto keep = (x > static_cast<T>(threshold)).template cast<T>();
    out.device(d) = keep * x;
  }
};

// Backward of ThresholdedReluFunctor: the upstream gradient flows only
// through elements whose input exceeded the "threshold" attribute.
template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto pass_through = (x > static_cast<T>(threshold)).template cast<T>();
    dx.device(d) = dout * pass_through;
  }
};

809 810 811 812 813 814 815 816
// Hard sigmoid: a linear map clamped into [0, 1],
//   out = min(max(slope * x + offset, 0), 1)
// with `slope` and `offset` taken from the op attributes of the same name.
template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto lower = static_cast<T>(0);
    const auto upper = static_cast<T>(1);
    auto linear = x * static_cast<T>(slope) + static_cast<T>(offset);
    out.device(d) = linear.cwiseMax(lower).cwiseMin(upper);
  }
};

// Backward of HardSigmoidFunctor, computed from the forward output `out`:
//   dx = dout * slope   where 0 < out < 1
//   dx = 0              elsewhere (the clamped region)
template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }
  // Reports whether "hard_sigmoid" is listed in InplaceOpSet.
  bool Inplace() { return IsInplace("hard_sigmoid"); }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // 1 for outputs strictly inside (0, 1), i.e. the un-clamped linear part.
    auto linear_region =
        ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
            .template cast<T>();
    dx.device(d) = dout * linear_region * static_cast<T>(slope);
  }
};

A
Abhinav Arora 已提交
843 844 845 846 847 848 849
// Swish activation:
//   out = x / (1 + exp(-beta * x))  ==  x * sigmoid(beta * x)
// `beta` is read from the "beta" op attribute.
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto denominator = static_cast<T>(1) + (static_cast<T>(-beta) * x).exp();
    out.device(d) = x / denominator;
  }
};

// Backward of SwishFunctor. With s = sigmoid(beta * x) and out = x * s:
//   dx = dout * (beta * out + s * (1 - beta * out))
template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto scaled_out = static_cast<T>(beta) * out;
    auto sig = static_cast<T>(1) /
               (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    dx.device(d) =
        dout * (scaled_out + sig * (static_cast<T>(1) - scaled_out));
  }
};

Q
qijun 已提交
873 874
}  // namespace operators
}  // namespace paddle
875

876 877 878 879
// X-macro enumerating every activation op handled by this header.
// `__macro` is invoked once per op as
//   __macro(op_name, ForwardFunctor, GradFunctor)
// and each invocation is followed by ';'. Ops whose gradient is identically
// zero (ceil, floor, round) share ZeroGradFunctor.
#define FOR_EACH_KERNEL_FUNCTOR(__macro)                                     \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);                      \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);             \
  __macro(exp, ExpFunctor, ExpGradFunctor);                                  \
  __macro(relu, ReluFunctor, ReluGradFunctor);                               \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                               \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);             \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                               \
  __macro(abs, AbsFunctor, AbsGradFunctor);                                  \
  __macro(ceil, CeilFunctor, ZeroGradFunctor);                               \
  __macro(floor, FloorFunctor, ZeroGradFunctor);                             \
  __macro(cos, CosFunctor, CosGradFunctor);                                  \
  __macro(sin, SinFunctor, SinGradFunctor);                                  \
  __macro(round, RoundFunctor, ZeroGradFunctor);                             \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);             \
  __macro(log, LogFunctor, LogGradFunctor);                                  \
  __macro(square, SquareFunctor, SquareGradFunctor);                         \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                            \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);                  \
  __macro(pow, PowFunctor, PowGradFunctor);                                  \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                            \
  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);                   \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);                   \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                            \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);               \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);            \
  __macro(elu, ELUFunctor, ELUGradFunctor);                                  \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);            \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor);         \
  __macro(swish, SwishFunctor, SwishGradFunctor);                            \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);