/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
                          "Cannot get input tensor X, variable name = %s",
                          context.op().Input("X"));

    auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
                            "Cannot get output tensor Out, variable name = %s",
                            context.op().Output("Out"));
    Out.mutable_data<T>(context.GetPlace());
    auto x = framework::EigenVector<T>::Flatten(X);
    auto out = framework::EigenVector<T>::Flatten(Out);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, out);
  }
};
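// Note (illustrative): a concrete kernel is this template paired with a
// functor, e.g.
//   ActivationKernel<paddle::platform::CPUDeviceContext, SigmoidFunctor<float>>
// The FOR_EACH_KERNEL_FUNCTOR macro at the end of this file enumerates the
// functor pairs used for such instantiations.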

template <typename DeviceContext, typename Functor>
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Out = context.Input<framework::Tensor>("Out");
    auto* dOut =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    dX->mutable_data<T>(context.GetPlace());

    auto dout = framework::EigenVector<T>::Flatten(*dOut);
    auto x = framework::EigenVector<T>::Flatten(*X);
    auto out = framework::EigenVector<T>::Flatten(*Out);
    auto dx = framework::EigenVector<T>::Flatten(*dX);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, out, dout, dx);
  }
};

template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return AttrPair(); }
};
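// A minimal sketch of the attribute contract (ScaleFunctor is hypothetical,
// not a functor registered below): GetAttrs() exposes pointers to the
// functor's fields, and the kernels above write the op's float attributes
// through those pointers before invoking operator().
//
//   template <typename T>
//   struct ScaleFunctor : public BaseActivationFunctor<T> {
//     float scale;
//     typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
//       return {{"scale", &scale}};
//     }
//     template <typename Device, typename X, typename Out>
//     void operator()(Device d, X x, Out out) const {
//       out.device(d) = x * static_cast<T>(scale);  // filled in by the kernel
//     }
//   };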

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
  }
};

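// d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)) = out * (1 - out),
// so the backward pass below only needs the forward output.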
template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
  }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
//           max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
};

// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.exp();
  }
};

template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
  }
};

// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
  }
};

// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.tanh();
  }
};

template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
  }
};

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x - x.tanh();
  }
};

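// d(tanhshrink(x))/dx = 1 - (1 - tanh(x)^2) = tanh(x)^2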
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }
};

// hardshrink(x) = x, if x > threshold or x < -threshold; 0 otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp1 = (x < static_cast<T>(-threshold)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    out.device(d) = x * (temp1 + temp2);
  }
};

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = (x < static_cast<T>(-threshold)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
  }
};

template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.sqrt();
  }
};

template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
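    // d(sqrt(x))/dx = 0.5 / sqrt(x) = 0.5 / out; numext::conj reduces to the
    // identity for real-valued scalar types.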
    const Out out_conj = Eigen::numext::conj(out);
    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
  }
};

// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.ceil();
  }
};

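// ceil, floor, and round are piecewise constant, so their gradient is zero
// almost everywhere; static_cast<T>(0) / x is simply a zero-valued expression
// with x's shape (note that it evaluates to NaN at x == 0).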
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(0) / x;
  }
};

// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.floor();
  }
};

template <typename T>
struct Sine {
  HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};

template <typename T>
struct Cosine {
  HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};

// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = -dout * x.unaryExpr(Sine<T>());
  }
};

// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Cosine<T>());
  }
};

// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Cosine<T>());
  }
};

// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Sine<T>());
  }
};

// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.round();
  }
};

// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.abs();
  }
};

template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.sign();
  }
};

// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / x;
  }
};

template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
  }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.square();
  }
};

template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
  }
};

template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
  // not polymorphism for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
  }
};

template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                       .template cast<T>();
  }
};

// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
                       .template cast<T>();
  }
};

// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may explode to inf,
// Using trick below for numerical stability
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
  }
};

// d(softplus(x))/dx = exp(x) / (1 + exp(x))
// For numerical stability:
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    dx.device(d) =
        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }
};

// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }
};

template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
  }
};

template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
  }
};

template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
  }
};

template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(alpha) *
                 (x < static_cast<T>(0)).template cast<T>().eval();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0)) +
                    (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
                        .cwiseMin(static_cast<T>(0));
  }
};

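// For x < 0, out = alpha * (exp(x) - 1), hence d(elu(x))/dx = alpha * exp(x)
// = out + alpha; for x > 0 the gradient is 1.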
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
                   dout * (out + static_cast<T>(alpha)) *
                       (x < static_cast<T>(0)).template cast<T>();
  }
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.pow(static_cast<T>(factor));
  }
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(factor) *
                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
  }
};

template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
  }
};

template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
  }
};

template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto th = static_cast<T>(threshold);
    out.device(d) = (x > th).template cast<T>() * x;
  }
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dout * (x > th).template cast<T>();
  }
};

template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
    out.device(d) =
        temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
                       .template cast<T>() *
                   static_cast<T>(slope);
  }
};

template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
  }
};

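// d(swish(x))/dx = beta * out + sigmoid(beta * x) * (1 - beta * out),
// where out = swish(x) and sigmoid(z) = 1 / (1 + exp(-z)).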
template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
    dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
  }
};

}  // namespace operators
}  // namespace paddle

#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
  __macro(ceil, CeilFunctor, ZeroGradFunctor);                       \
  __macro(floor, FloorFunctor, ZeroGradFunctor);                     \
  __macro(cos, CosFunctor, CosGradFunctor);                          \
  __macro(sin, SinFunctor, SinGradFunctor);                          \
  __macro(round, RoundFunctor, ZeroGradFunctor);                     \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
  __macro(log, LogFunctor, LogGradFunctor);                          \
  __macro(square, SquareFunctor, SquareGradFunctor);                 \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
  __macro(pow, PowFunctor, PowGradFunctor);                          \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
  __macro(swish, SwishFunctor, SwishGradFunctor);                    \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
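// Sketch of the intended consumer (assumed, mirroring activation_op.cc and
// activation_op.cu): a per-device registration macro is expanded once per
// activation via FOR_EACH_KERNEL_FUNCTOR, e.g.
//
//   #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
//     REGISTER_OP_CPU_KERNEL(                                               \
//         act_type,                                                         \
//         ops::ActivationKernel<paddle::platform::CPUDeviceContext,         \
//                               ops::functor<float>>);                      \
//     REGISTER_OP_CPU_KERNEL(                                               \
//         act_type##_grad,                                                  \
//         ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,     \
//                                   ops::grad_functor<float>>);
//
//   FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);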