/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Y = context.Output<framework::Tensor>("Y");
    Y->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(*X);
    auto y = framework::EigenVector<T>::Flatten(*Y);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, y);
  }
};
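
// For example (illustrative only), instantiating
//   ActivationKernel<platform::CPUDeviceContext, ReluFunctor<float>>
// flattens X into a 1-D Eigen vector and evaluates the functor on the CPU
// Eigen device. Functors that expose attributes via GetAttrs() (see
// BaseActivationFunctor below) have their fields filled from the op's
// attributes before the functor is invoked.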

template <typename DeviceContext, typename Functor>
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Y = context.Input<framework::Tensor>("Y");
    auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    dX->mutable_data<T>(context.GetPlace());

    auto dy = framework::EigenVector<T>::Flatten(*dY);
    auto x = framework::EigenVector<T>::Flatten(*X);
    auto y = framework::EigenVector<T>::Flatten(*Y);
    auto dx = framework::EigenVector<T>::Flatten(*dX);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, y, dy, dx);
  }
};

template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return AttrPair(); }
};
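
// Example (an illustrative sketch only; not one of the registered
// activations): a functor with a single "scale" attribute. ActivationKernel
// copies the float attribute named "scale" into the `scale` member through
// the pointer returned by GetAttrs() before calling operator().
template <typename T>
struct ExampleScaleFunctor : public BaseActivationFunctor<T> {
  float scale;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale", &scale}};
  }
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    // y = scale * x, evaluated element-wise on the given Eigen device.
    y.device(d) = static_cast<T>(scale) * x;
  }
};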

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
  }
};

template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * y * (static_cast<T>(1) - y);
  }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// y = -log( exp(0) + exp(-x)) [since exp(0) = 1]
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
//           max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = -log(exp(max(-x, 0))) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};
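
// Sanity check of the stable form: for x = -1000, the naive
// -log(1 + exp(1000)) overflows, while temp = max(-x, 0) = 1000 gives
// -1000 - log(exp(-1000) + exp(0)) ~= -1000, the correct value.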

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
};

// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.exp();
  }
};

template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * y;
  }
};

// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
  }
};

// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.tanh();
  }
};

template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * (static_cast<T>(1) - y * y);
  }
};

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x - x.tanh();
  }
};

template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * (x.tanh() * x.tanh());
  }
};

// hard_shrink(x) = x, if x > threshold or x < -threshold; 0 otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    y.device(d) = x * (temp1 + temp2);
  }
};

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
  }
};

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
  }
};

template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
  }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.sqrt();
  }
};

template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    const Y y_conj = Eigen::numext::conj(y);
    dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
  }
};

// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.ceil();
  }
};

template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = static_cast<T>(0) / x;
  }
};

// floor(x) = the largest integer not greater than x
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.floor();
  }
};

// round(x) = nearest integer to x
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.round();
  }
};

// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.abs();
  }
};

template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * x.sign();
  }
};

// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = static_cast<T>(1) / x;
  }
};

template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * static_cast<T>(-1) * y * y;
  }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * (static_cast<T>(1) / x);
  }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.square();
  }
};

template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * static_cast<T>(2) * x;
  }
};

template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: GetAttrs here intentionally hides `BaseActivationFunctor<T>::GetAttrs`;
  // this is name hiding rather than polymorphism, chosen for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) =
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
  }
};

template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy *
                   ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                       .template cast<T>();
  }
};

// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) =
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy *
                   ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
                       .template cast<T>();
  }
};

// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may overflow to inf;
// we use the log-sum-exp trick below for numerical stability:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
  }
};
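
// Sanity check of the stable form: for x = 1000, the naive log(1 + exp(1000))
// overflows, while temp = max(x, 0) = 1000 gives
// 1000 + log(exp(-1000) + exp(0)) ~= 1000, the correct value.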

// d(softplus(x))/dx = exp(x) / (1 + exp(x))
// For numerical stability:
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }
};

// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    y.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    dx.device(d) =
        dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }
};

template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
    y.device(d) = (static_cast<T>(1) + temp.exp()).log();
  }
};

template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
    dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
  }
};

template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
  }
};

template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp1 = static_cast<T>(alpha) *
                 (x < static_cast<T>(0)).template cast<T>().eval();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
  }
};

template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.cwiseMax(static_cast<T>(0)) +
                  (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
                      .cwiseMin(static_cast<T>(0));
  }
};

template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() +
                   dy * (y + static_cast<T>(alpha)) *
                       (x < static_cast<T>(0)).template cast<T>();
  }
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x.pow(static_cast<T>(factor));
  }
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) = dy * static_cast<T>(factor) *
                   x.pow(static_cast<T>(factor - static_cast<T>(1)));
  }
};

template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) =
        static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
  }
};

template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) = dy * a * b * (static_cast<T>(1) - temp);
  }
};

template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto th = static_cast<T>(threshold);
    y.device(d) = (x > th).template cast<T>() * x;
  }
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dy * (x > th).template cast<T>();
  }
};

template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
    y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) =
        dy *
        ((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
        static_cast<T>(slope);
  }
};

template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
  }
};

template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto temp2 = temp1 * (static_cast<T>(1) - (beta * y));
    dx.device(d) = dy * ((beta * y) + temp2);
  }
};

}  // namespace operators
}  // namespace paddle

#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
  __macro(relu, ReluFunctor, ReluGradFunctor);                       \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
  __macro(ceil, CeilFunctor, ZeroGradFunctor);                       \
  __macro(floor, FloorFunctor, ZeroGradFunctor);                     \
  __macro(round, RoundFunctor, ZeroGradFunctor);                     \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
  __macro(log, LogFunctor, LogGradFunctor);                          \
  __macro(square, SquareFunctor, SquareGradFunctor);                 \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
  __macro(pow, PowFunctor, PowGradFunctor);                          \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
  __macro(swish, SwishFunctor, SwishGradFunctor);                    \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
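
// A minimal sketch of how the X-macro above is expected to be consumed by the
// registration code (the real registration lives in activation_op.cc /
// activation_op.cu; the exact macro name and kernel list below are assumptions
// for illustration):
//
//   #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)  \
//     REGISTER_OP_CPU_KERNEL(                                                 \
//         act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
//                                         ops::functor<float>>);              \
//     REGISTER_OP_CPU_KERNEL(                                                 \
//         act_type##_grad,                                                    \
//         ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
//                                   ops::grad_functor<float>>);
//
//   FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);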