/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
                          "Cannot get input tensor X, variable name = %s",
                          context.op().Input("X"));

    auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
                            "Cannot get output tensor Out, variable name = %s",
                            context.op().Output("Out"));
    Out.mutable_data<T>(context.GetPlace());
    auto x = framework::EigenVector<T>::Flatten(X);
    auto out = framework::EigenVector<T>::Flatten(Out);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, out);
  }
};

template <typename DeviceContext, typename Functor>
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Out = context.Input<framework::Tensor>("Out");
    auto* dOut =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    dX->mutable_data<T>(context.GetPlace());

    auto dout = framework::EigenVector<T>::Flatten(*dOut);
    auto x = framework::EigenVector<T>::Flatten(*X);
    auto out = framework::EigenVector<T>::Flatten(*Out);
    auto dx = framework::EigenVector<T>::Flatten(*dX);
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(*place, x, out, dout, dx);
  }
};

template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return AttrPair(); }
};
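
// Illustrative sketch only (the functor below is a hypothetical example, not
// an operator defined in this file): a functor exposes tunable attributes by
// returning {name, pointer} pairs from GetAttrs(), and ActivationKernel
// copies the op attribute of the same name into the functor before applying
// it:
//
//   template <typename T>
//   struct ScaledReluFunctor : public BaseActivationFunctor<T> {
//     float scale;  // filled in from the "scale" op attribute
//     typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
//       return {{"scale", &scale}};
//     }
//     template <typename Device, typename X, typename Out>
//     void operator()(Device d, X x, Out out) const {
//       out.device(d) = x.cwiseMax(static_cast<T>(0)) * static_cast<T>(scale);
//     }
//   };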

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
  }
};

// d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)) = out * (1 - out),
// so the gradient is computed from the forward output alone.
template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
  }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
//           max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
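// As a quick sanity check (illustrative numbers, assuming T = double): for
// x = -1000 the naive form needs exp(1000), which overflows, while the
// stable form gives -(1000 + log(exp(-1000) + exp(0))) ~= -1000, as expected.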
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
};

// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.exp();
  }
};

// d(exp(x))/dx = exp(x) = out
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
  }
};

// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
  }
};

// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.tanh();
  }
};

// d(tanh(x))/dx = 1 - tanh(x)^2 = 1 - out * out
template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
  }
};

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x - x.tanh();
  }
};

// d(tanhshrink(x))/dx = 1 - (1 - tanh(x)^2) = tanh(x)^2
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }
};

// hard_shrink(x) = x, if x > threshold or x < -threshold
//                = 0, otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    out.device(d) = x * (temp1 + temp2);
  }
};

template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
  }
};

template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>().eval();
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.sqrt();
  }
};

// d(sqrt(x))/dx = 1 / (2 * sqrt(x)) = 0.5 / out
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const Out out_conj = Eigen::numext::conj(out);
    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
  }
};

// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.ceil();
  }
};

// ceil, floor and round are piecewise constant, so their gradient is zero;
// 0 / x produces a zero tensor with the same shape as x.
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(0) / x;
  }
};

// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.floor();
  }
};

// round(x) = [x], the nearest integer to x
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.round();
  }
};

// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.abs();
  }
};

template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.sign();
  }
};

// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = static_cast<T>(1) / x;
  }
};

// d(1/x)/dx = -1 / x^2 = -out * out
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
  }
};

// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.log();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.square();
  }
};

template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
  }
};

template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // NOTE: GetAttrs here intentionally hides `BaseActivationFunctor<T>::GetAttrs`
  // (name hiding rather than virtual dispatch, for speed).
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
  }
};

template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                       .template cast<T>();
  }
};

// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
                       .template cast<T>();
  }
};

// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may overflow to inf,
// so we use the log-sum-exp trick below for numerical stability:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
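// For example (illustrative numbers): at x = 1000 the naive form needs
// exp(1000), which overflows, while the stable form evaluates to
// 1000 + log(exp(-1000) + exp(0)) ~= 1000, the correct value.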
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
  }
};

// d(softplus(x))/dx = exp(x) / (1 + exp(x))
// For numerical stability:
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
    dx.device(d) =
        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }
};

// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) {
    dx.device(d) =
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }
};

// soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold)))
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
  }
};

template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
  }
};

// leaky_relu(x) = max(x, alpha * x)
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
  }
};

template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(alpha) *
                 (x < static_cast<T>(0)).template cast<T>().eval();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
};

// elu(x) = x,                     if x > 0
//        = alpha * (exp(x) - 1),  otherwise
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.cwiseMax(static_cast<T>(0)) +
                    (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
                        .cwiseMin(static_cast<T>(0));
  }
};

template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
                   dout * (out + static_cast<T>(alpha)) *
                       (x < static_cast<T>(0)).template cast<T>();
  }
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.pow(static_cast<T>(factor));
  }
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  float factor;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(factor) *
                   x.pow(static_cast<T>(factor - static_cast<T>(1)));
  }
};

// stanh(x) = scale_b * tanh(scale_a * x)
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
  }
};

template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
  }
};

// thresholded_relu(x) = x, if x > threshold
//                     = 0, otherwise
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto th = static_cast<T>(threshold);
    out.device(d) = (x > th).template cast<T>() * x;
  }
};

template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dout * (x > th).template cast<T>();
  }
};

// hard_sigmoid(x) = clip(slope * x + offset, 0, 1)
template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
    out.device(d) =
        temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  float slope;
  float offset;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout *
                   ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
                       .template cast<T>() *
                   static_cast<T>(slope);
  }
};

// swish(x) = x * sigmoid(beta * x)
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
  }
};

// d(swish(x))/dx = beta * swish(x) + sigmoid(beta * x) * (1 - beta * swish(x))
//                = beta * out + temp1 * (1 - beta * out)
template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto temp2 = temp1 * (static_cast<T>(1) - (beta * out));
    dx.device(d) = dout * ((beta * out) + temp2);
  }
};

}  // namespace operators
}  // namespace paddle

#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
  __macro(ceil, CeilFunctor, ZeroGradFunctor);                       \
  __macro(floor, FloorFunctor, ZeroGradFunctor);                     \
  __macro(round, RoundFunctor, ZeroGradFunctor);                     \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
  __macro(log, LogFunctor, LogGradFunctor);                          \
  __macro(square, SquareFunctor, SquareGradFunctor);                 \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
  __macro(pow, PowFunctor, PowGradFunctor);                          \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
  __macro(swish, SwishFunctor, SwishGradFunctor);                    \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
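
// Illustrative sketch of how this X-macro is typically consumed. The exact
// registration lives in the companion activation_op.cc/.cu, so the helper
// macro name below is an assumption, not a definition from this header:
//
//   #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
//     REGISTER_OP_CPU_KERNEL(                                               \
//         act_type,                                                         \
//         paddle::operators::ActivationKernel<                              \
//             paddle::platform::CPUDeviceContext,                           \
//             paddle::operators::functor<float>>);                          \
//     REGISTER_OP_CPU_KERNEL(                                               \
//         act_type##_grad,                                                  \
//         paddle::operators::ActivationGradKernel<                          \
//             paddle::platform::CPUDeviceContext,                           \
//             paddle::operators::grad_functor<float>>);
//
//   FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);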