/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include <utility>
#include <vector>

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Out = context.Output<framework::Tensor>("Out");
    Out->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(*X);
    auto out = framework::EigenVector<T>::Flatten(*Out);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, out);
  }
};
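
// A minimal usage sketch (an assumption, not part of this header: the exact
// registration macro and op name below are illustrative; the real
// registrations live in activation_op.cc):
//
//   REGISTER_OP_CPU_KERNEL(
//       sigmoid,
//       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
//                             ops::SigmoidFunctor<float>>,
//       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
//                             ops::SigmoidFunctor<double>>);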

template <typename DeviceContext, typename Functor>
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Out = context.Input<framework::Tensor>("Out");
    auto* dOut =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    dX->mutable_data<T>(context.GetPlace());

    auto dout = framework::EigenVector<T>::Flatten(*dOut);
    auto x = framework::EigenVector<T>::Flatten(*X);
    auto out = framework::EigenVector<T>::Flatten(*Out);
    auto dx = framework::EigenVector<T>::Flatten(*dX);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, out, dout, dx);
  }
};

template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return AttrPair(); }
};
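
// Sketch of how a new activation would plug into the classes above
// (ScaleFunctor is hypothetical and not defined in this file): a functor
// exposes its float attributes through GetAttrs(), and ActivationKernel
// fills them in from the op's attribute map before calling operator().
//
//   template <typename T>
//   struct ScaleFunctor : public BaseActivationFunctor<T> {
//     float scale;
//     typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
//       return {{"scale", &scale}};
//     }
//     template <typename Device, typename X, typename Out>
//     void operator()(Device d, X x, Out out) const {
//       out.device(d) = x * static_cast<T>(scale);
//     }
//   };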

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
  }
};

template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
  }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
//           max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
};

// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.exp();
  }
};

template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
  }
};

// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
  }
};

// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.tanh();
  }
};

template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
  }
};

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x - x.tanh();
  }
};

template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }
};

// hardshrink(x) = x, if x > threshold or x < -threshold; 0 otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp1 = (x < -static_cast<T>(threshold)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    out.device(d) = x * (temp1 + temp2);
  }
};

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = (x < -static_cast<T>(threshold)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
  }
};

template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.sqrt();
  }
};

template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // d(sqrt(x))/dx = 0.5 / sqrt(x) = 0.5 / out
    dx.device(d) = static_cast<T>(0.5) * dout / out;
  }
};

// ceil(x) = smallest integer not less than x
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.ceil();
  }
};

template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    // ceil/floor/round have zero gradient almost everywhere.
    dx.device(d) = static_cast<T>(0) / x;
  }
};

// floor(x) = largest integer not greater than x
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.floor();
  }
};

// round(x) = nearest integer to x
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.round();
  }
};

// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.abs();
  }
};

template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.sign();
  }
};

// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / x;
  }
};

template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
  }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.square();
  }
};

template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
  }
};

template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: This GetAttrs deliberately shadows BaseActivationFunctor<T>::GetAttrs;
  // we avoid virtual dispatch here for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
  }
};

template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                       .template cast<T>();
  }
};

// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
                       .template cast<T>();
  }
};

// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may overflow to inf.
// For numerical stability, use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
  }
};

// d(softplus(x))/dx = exp(x) / (1 + exp(x))
// For numerical stability:
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    dx.device(d) =
        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }
};

// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }
};

template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
  }
};

template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
  }
};

template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
  }
};

template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(alpha) *
                 (x < static_cast<T>(0)).template cast<T>().eval();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0)) +
                    (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
                        .cwiseMin(static_cast<T>(0));
  }
};

template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
                   dout * (out + static_cast<T>(alpha)) *
                       (x < static_cast<T>(0)).template cast<T>();
  }
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.pow(static_cast<T>(factor));
  }
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(factor) *
                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
  }
};

template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
  }
};

template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
  }
};

template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto th = static_cast<T>(threshold);
    out.device(d) = (x > th).template cast<T>() * x;
  }
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dout * (x > th).template cast<T>();
  }
};

template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
    out.device(d) =
        temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
                       .template cast<T>() *
                   static_cast<T>(slope);
  }
};

template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
  }
};

template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
    dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
  }
};

}  // namespace operators
}  // namespace paddle

#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
  __macro(relu, ReluFunctor, ReluGradFunctor);                       \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
  __macro(ceil, CeilFunctor, ZeroGradFunctor);                       \
  __macro(floor, FloorFunctor, ZeroGradFunctor);                     \
  __macro(round, RoundFunctor, ZeroGradFunctor);                     \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
  __macro(log, LogFunctor, LogGradFunctor);                          \
  __macro(square, SquareFunctor, SquareGradFunctor);                 \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
  __macro(pow, PowFunctor, PowGradFunctor);                          \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
  __macro(swish, SwishFunctor, SwishGradFunctor);                    \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
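
// Illustrative consumption of FOR_EACH_KERNEL_FUNCTOR (a sketch under the
// assumption that the real registration macro lives in activation_op.cc /
// activation_op.cu; the macro name below is hypothetical):
//
//   #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
//     REGISTER_OP_CPU_KERNEL(                                               \
//         act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
//                       ops::functor<float>>);                              \
//     REGISTER_OP_CPU_KERNEL(                                               \
//         act_type##_grad,                                                  \
//         ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,     \
//                                   ops::grad_functor<float>>);
//
//   FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);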