activation_op.cu 61.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
15
#include "paddle/fluid/operators/math/math_cuda_utils.h"
16
#include "paddle/fluid/platform/bfloat16.h"
17
#include "paddle/fluid/platform/cuda_device_function.h"
18

19 20 21
namespace paddle {
namespace operators {

// Forward ReLU: positive inputs pass through unchanged, everything else is
// clamped to zero, i.e. relu(x) = max(x, 0).
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  __device__ __forceinline__ T operator()(const T& x) const {
    if (x > zero) {
      return x;
    }
    return zero;
  }
};

// Backward ReLU expressed through the forward output:
// dx = dout where out > 0, otherwise 0.
template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const bool active = out > zero;
    return active ? dout : zero;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Leaky ReLU: identity for positive inputs, slope `alpha` elsewhere:
// leakyrelu(x) = x > 0 ? x : alpha * x.
template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Pairs the attribute name "alpha" with the address of the `alpha` member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T& x) const {
    if (x > zero) {
      return x;
    }
    const T slope = static_cast<T>(alpha);
    return slope * x;
  }
};

// Backward Leaky ReLU: dx = dout for x > 0, alpha * dout otherwise.
template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Pairs the attribute name "alpha" with the address of the `alpha` member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    if (x > zero) {
      return dout;
    }
    return static_cast<T>(alpha) * dout;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Logistic sigmoid, sigmoid(x) = 1 / (1 + exp(-x)), evaluated in
// MPType (details::MPTypeTrait<T>::Type) and cast back to T.
template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType neg_x = -static_cast<MPType>(arg_x);
    const MPType denom = one + exp(neg_x);
    return static_cast<T>(one / denom);
  }
};

// Sigmoid backward in terms of the forward output:
// dx = dout * out * (1 - out), evaluated directly in T.
template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const T scaled = dout * out;
    return scaled * (one - out);
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// SiLU (swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x)),
// evaluated in MPType and cast back to T.
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType val = static_cast<MPType>(arg_x);
    const MPType denom = one + exp(-val);
    return static_cast<T>(val / denom);
  }
};

// Gradient of SiLU with respect to its input x:
//   d/dx [x / (1 + exp(-x))]
//     = (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
//     = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
// so dx = dout * sigmoid(x) * (1 + x * (1 - sigmoid(x))).
// Evaluated in MPType (details::MPTypeTrait<T>::Type) and cast back to T.
template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    // temp = sigmoid(x); the closed form above avoids recomputing exp(-x).
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Log-sigmoid: logsigmoid(x) = log(1 / (1 + exp(-x))).
// exp(-x) overflows for large negative x, so both terms are shifted by
// m = max(-x, 0):
//   logsigmoid(x) = -(m + log(exp(-m) + exp(-x - m)))
template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    MPType shift = -x;
    if (x > zero) {
      shift = zero;
    }
    const MPType sum = exp(-shift) + exp(-x - shift);
    return static_cast<T>(-shift - log(sum));
  }
};

// Backward log-sigmoid: dx = dout * exp(-x) / (1 + exp(-x)).
// Stabilized with m = max(-x, 0):
//   dx = dout * exp(-x - m) / (exp(-m) + exp(-x - m))
template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType shift = (x > zero) ? zero : -x;
    const MPType shifted_exp = exp(-x - shift);
    const MPType ratio = shifted_exp / (exp(-shift) + shifted_exp);
    return static_cast<T>(dout * ratio);
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise arctangent, evaluated in MPType and cast back to T.
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    return static_cast<T>(atan(static_cast<MPType>(arg_x)));
  }
};

// Backward arctangent: dx = dout / (1 + x^2), evaluated directly in T.
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T denom = one + x * x;
    return dout / denom;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Soft shrinkage:
//   softshrink(x) = x - lambda  if x > lambda
//                   x + lambda  if x < -lambda
//                   0           otherwise
// Written branch-free: each case is weighted by a 0/1 mask cast to T.
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Pairs the attribute name "lambda" with the address of the `lambda` member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  __device__ __forceinline__ T operator()(const T& x) const {
    const T l = static_cast<T>(lambda);
    const T above = static_cast<T>(x > l);
    const T below = static_cast<T>(x < -l);
    const T upper_part = above * (x - l);
    const T lower_part = below * (x + l);
    return upper_part + lower_part;
  }
};

// Backward soft shrinkage: the gradient is zero inside [-lambda, lambda]
// and passes `dout` through outside that interval.
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  // Pairs the attribute name "lambda" with the address of the `lambda` member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T l = static_cast<T>(lambda);
    const bool inside = x >= -l && x <= l;
    if (inside) {
      return zero;
    }
    return dout;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise ceiling, evaluated in MPType and cast back to T.
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    return static_cast<T>(ceil(static_cast<MPType>(arg_x)));
  }
};

// Element-wise floor, evaluated in MPType and cast back to T.
template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    return static_cast<T>(floor(static_cast<MPType>(arg_x)));
  }
};

// Element-wise rounding via round(), evaluated in MPType and cast back to T.
template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    return static_cast<T>(round(static_cast<MPType>(arg_x)));
  }
};

// Gradient functor shared by ceil, floor and round: the gradient is taken to
// be zero, so no forward tensor is needed (kNoDeps).
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  // The argument is ignored; the result is always 0.
  __device__ __forceinline__ T operator()(const T& x) const {
    const T kZero = static_cast<T>(0.0f);
    return kZero;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};

// Element-wise cosine, evaluated in MPType and cast back to T.
template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType angle = static_cast<MPType>(arg_x);
    return static_cast<T>(cos(angle));
  }
};

// Backward cosine: dx = -dout * sin(x), evaluated in MPType.
template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType neg_dout = -static_cast<MPType>(arg_dout);
    const MPType angle = static_cast<MPType>(arg_x);
    return static_cast<T>(neg_dout * sin(angle));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise sine, evaluated in MPType and cast back to T.
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType angle = static_cast<MPType>(arg_x);
    return static_cast<T>(sin(angle));
  }
};

// Backward sine: dx = dout * cos(x), evaluated in MPType.
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType angle = static_cast<MPType>(arg_x);
    return static_cast<T>(grad * cos(angle));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise tangent, evaluated in MPType and cast back to T.
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType angle = static_cast<MPType>(arg_x);
    return static_cast<T>(tan(angle));
  }
};

// Backward tangent: dx = dout / cos(x)^2, evaluated in MPType.
template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType c = cos(static_cast<MPType>(arg_x));
    return static_cast<T>(grad / (c * c));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise arcsine, evaluated in MPType and cast back to T.
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(asin(value));
  }
};

// Backward arcsine: dx = dout / sqrt(1 - x^2), evaluated in MPType.
template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType root = sqrt(one - value * value);
    return static_cast<T>(grad / root);
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise arccosine, evaluated in MPType and cast back to T.
template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(acos(value));
  }
};

// Backward arccosine: dx = -dout / sqrt(1 - x^2), evaluated in MPType.
template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType neg_grad = -static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType root = sqrt(one - value * value);
    return static_cast<T>(neg_grad / root);
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise hyperbolic cosine, evaluated in MPType and cast back to T.
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(cosh(value));
  }
};

// Backward hyperbolic cosine: dx = dout * sinh(x), evaluated in MPType.
template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(grad * sinh(value));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise hyperbolic sine, evaluated in MPType and cast back to T.
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(sinh(value));
  }
};

// Backward hyperbolic sine: dx = dout * cosh(x), evaluated in MPType.
template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(grad * cosh(value));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise hyperbolic tangent, evaluated in MPType and cast back to T.
template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType result = tanh(value);
    return static_cast<T>(result);
  }
};

// Backward tanh in terms of the forward output: dx = dout * (1 - out^2),
// evaluated directly in T.
template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const T slope = one - out * out;
    return dout * slope;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Element-wise reciprocal: reciprocal(x) = 1 / x, evaluated in T.
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& x) const {
    const T inverse = one / x;
    return inverse;
  }
};

// Backward reciprocal in terms of the forward output: dx = -dout * out^2.
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const T neg_scaled = -dout * out;
    return neg_scaled * out;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Element-wise natural exponential, evaluated in MPType and cast back to T.
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(exp(value));
  }
};

// Backward exp: since d/dx exp(x) = exp(x) = out, dx = dout * out.
template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const T grad = dout * out;
    return grad;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Element-wise expm1(x) = exp(x) - 1, evaluated in MPType via the dedicated
// expm1() function and cast back to T.
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(expm1(value));
  }
};

// Backward expm1 in terms of the forward output.
// d/dx (exp(x) - 1) = exp(x) = out + 1, so
//   dx = dout * (out + 1),
// computed here as dout * out + dout.
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    return dout * out + dout;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Element-wise natural logarithm, evaluated in MPType and cast back to T.
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(log(value));
  }
};

// Backward natural logarithm: dx = dout / x, evaluated directly in T.
template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T grad = dout / x;
    return grad;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise square: square(x) = x * x, evaluated in T.
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T& x) const {
    const T squared = x * x;
    return squared;
  }
};

// Backward square: dx = dout * 2 * x, evaluated directly in T.
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T doubled = dout * two;
    return doubled * x;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise square root, evaluated in MPType and cast back to T.
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(sqrt(value));
  }
};

// Backward square root in terms of the forward output:
// dx = 0.5 * dout / out, evaluated directly in T.
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const T halved = one_half * dout;
    return halved / out;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Element-wise reciprocal square root 1 / sqrt(x), evaluated in MPType via
// rsqrt() and cast back to T.
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(rsqrt(value));
  }
};

// Backward rsqrt in terms of the forward output: dx = -0.5 * dout * out^3,
// evaluated directly in T.
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    T acc = minus_one_half * dout;
    acc = acc * out;
    acc = acc * out;
    return acc * out;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Element-wise log1p(x) = log(1 + x), evaluated in MPType and cast back to T.
// Uses the dedicated log1p() math function rather than log(1 + x): forming
// 1 + x first discards the low-order bits of small |x| (in float,
// 1 + 1e-8f rounds to exactly 1, so log(1 + x) would return 0 instead of
// ~1e-8). This mirrors CudaExpm1Functor's use of expm1().
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Kept so the functor's public members are unchanged; operator() no longer
  // needs it.
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log1p(x));
  }
};

// Backward log1p: dx = dout / (1 + x), evaluated directly in T.
template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T denom = one + x;
    return dout / denom;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise base-2 logarithm, evaluated in MPType and cast back to T.
template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(log2(value));
  }
};

// Backward base-2 logarithm: dx = dout / (x * ln 2). The constant ln 2 is
// computed once in MPType when the functor is constructed and stored as T.
template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T denom = x * log_two;
    return dout / denom;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Element-wise base-10 logarithm, evaluated in MPType and cast back to T.
template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(log10(value));
  }
};

// Backward base-10 logarithm: dx = dout / (x * ln 10). The constant ln 10 is
// computed once in MPType when the functor is constructed and stored as T.
template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T denom = x * log_ten;
    return dout / denom;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Bounded ReLU: clips x into [t_min, t_max]:
// brelu(x) = min(max(x, t_min), t_max).
template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // Pairs each attribute name with the address of the matching member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  __device__ __forceinline__ T operator()(const T& x) const {
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    T clipped = lo;
    if (x > lo) {
      clipped = x;
    }
    if (clipped < hi) {
      return clipped;
    }
    return hi;
  }
};

// Backward bounded ReLU: the gradient passes through only strictly inside
// (t_min, t_max); elsewhere it is zero.
template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float t_min;
  float t_max;

  // Pairs each attribute name with the address of the matching member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    const bool in_range = x > lo && x < hi;
    return in_range ? dout : zero;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Soft ReLU: soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold))),
// evaluated in MPType. The clip is only meaningful for threshold >= 0.
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Pairs the attribute name "threshold" with the address of the member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType t = static_cast<MPType>(threshold);
    MPType clipped = x < t ? x : t;
    clipped = clipped > -t ? clipped : -t;
    return static_cast<T>(log(one + exp(clipped)));
  }
};

// Backward soft ReLU in terms of the forward output:
//   dx = dout * (1 - exp(-out))  if -threshold < out < threshold
//   dx = 0                       otherwise
// Evaluated in MPType; threshold is expected to be non-negative.
template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Pairs the attribute name "threshold" with the address of the member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_out) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType out = static_cast<MPType>(arg_out);
    const MPType t = static_cast<MPType>(threshold);
    const bool in_range = out > -t && out < t;
    if (!in_range) {
      return static_cast<T>(0.0f);
    }
    return static_cast<T>(dout * (one - exp(-out)));
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Scaled tanh: stanh(x) = scale_b * tanh(scale_a * x), evaluated in MPType.
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  // Pairs each attribute name with the address of the matching member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType inner = static_cast<MPType>(scale_a) * static_cast<MPType>(arg_x);
    const MPType outer = static_cast<MPType>(scale_b);
    return static_cast<T>(outer * tanh(inner));
  }
};

// Backward scaled tanh:
// dx = dout * a * b * (1 - tanh(a * x)^2), with a = scale_a, b = scale_b,
// evaluated in MPType.
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  // Pairs each attribute name with the address of the matching member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType a = static_cast<MPType>(scale_a);
    const MPType b = static_cast<MPType>(scale_b);
    const MPType th = tanh(a * static_cast<MPType>(arg_x));
    const MPType scaled = grad * a * b;
    return static_cast<T>(scaled * (one - th * th));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Softplus with linear fallback:
//   softplus(x) = x                          if beta * x > threshold
//               = log(1 + exp(beta * x)) / beta  otherwise
// Evaluated in MPType and cast back to T.
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  // Pairs each attribute name with the address of the matching member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType b = static_cast<MPType>(beta);
    const MPType t = static_cast<MPType>(threshold);
    const MPType x_beta = x * beta;
    if (x_beta > t) {
      return static_cast<T>(x);
    }
    return static_cast<T>(log(one + exp(x_beta)) / b);
  }
};

// Backward softplus:
//   dx = dout                             if beta * x > threshold
//   dx = dout / (1 + exp(-beta * x))      otherwise
// Evaluated in MPType.
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  // Pairs each attribute name with the address of the matching member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType t = static_cast<MPType>(threshold);
    const MPType x_beta = static_cast<MPType>(arg_x) * beta;
    if (x_beta > t) {
      // Linear region: the gradient passes straight through.
      return arg_dout;
    }
    const MPType grad = static_cast<MPType>(arg_dout);
    return static_cast<T>(grad / (one + exp(-x_beta)));
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Softsign: softsign(x) = x / (1 + |x|), evaluated directly in T.
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& x) const {
    const T denom = one + abs(x);
    return x / denom;
  }
};

// Backward softsign: dx = dout / (1 + |x|)^2, evaluated directly in T.
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T base = one + abs(x);
    const T base_sq = base * base;
    return dout / base_sq;
  }

  // The gradient is computed from the forward input `x`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// ReLU capped at the `threshold` attribute:
// relu6(x) = min(max(0, x), threshold).
template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Pairs the attribute name "threshold" with the address of the member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T& x) const {
    const T t = static_cast<T>(threshold);
    if (x <= zero) {
      return zero;
    }
    return x < t ? x : t;
  }
};

// Backward capped ReLU in terms of the forward output:
// dx = dout where 0 < out < threshold, otherwise 0.
template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Pairs the attribute name "threshold" with the address of the member.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const T t = static_cast<T>(threshold);
    const bool linear_region = out > zero && out < t;
    if (linear_region) {
      return dout;
    }
    return zero;
  }

  // The gradient is computed from the forward output `out`.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x), evaluated in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType tanh_x = tanh(x);
    return static_cast<T>(x - tanh_x);
  }
};

template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d(x - tanh(x))/dx = tanh(x)^2, so dx = dout * tanh(x)^2 (computed in MPType).
  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType tanh_x = tanh(static_cast<MPType>(arg_x));
    return static_cast<T>(dout * tanh_x * tanh_x);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = 0 if -threshold < x < threshold, otherwise x
  __device__ __forceinline__ T operator()(const T& x) const {
    const T t = static_cast<T>(threshold);
    const bool inside_band = x > -t && x < t;
    return inside_band ? zero : x;
  }
};

template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // The gradient is zero inside the open band (-threshold, threshold) and
  // passes dout through unchanged outside it.
  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T t = static_cast<T>(threshold);
    const bool inside_band = x > -t && x < t;
    return inside_band ? zero : dout;
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = min(max(x * slope + offset, 0), 1). The saturation
  // points therefore depend on the slope and offset attributes.
  __device__ __forceinline__ T operator()(const T& x) const {
    const T linear = x * static_cast<T>(slope) + static_cast<T>(offset);
    const T clamped_low = linear > zero ? linear : zero;
    return clamped_low < one ? clamped_low : one;
  }
};

template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // Inside the unsaturated range 0 < out < 1 the slope is `slope`; in the
  // saturated ranges the gradient is zero.
  __device__ __forceinline__ T operator()(const T& dout, const T& out) const {
    const bool unsaturated = out > zero && out < one;
    return unsaturated ? dout * static_cast<T>(slope) : zero;
  }

  // The backward pass reads the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)), in MPType.
  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType neg_bx = -static_cast<MPType>(beta) * x;
    return static_cast<T>(x / (one + exp(neg_bx)));
  }
};

template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x * s, where s = sigmoid(b * x) = 1 / (1 + exp(-b * x)). So
  //   dx = dout * (s + b * x * s * (1 - s))
  //      = dout * (1 + exp(-b * x) + b * x * exp(-b * x)) / (1 + exp(-b * x))^2
  // Below, temp2 = b * x * s and temp3 = s * (1 - b * x * s), so
  // temp2 + temp3 = s + b * x * s * (1 - s).
  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType temp1 = one / (one + exp(-b * x));  // s = sigmoid(b * x)
    MPType out = x * temp1;                    // forward output swish(x)
    MPType temp2 = b * out;
    MPType temp3 = temp1 * (one - temp2);
    return static_cast<T>(dout * (temp2 + temp3));
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // thresholded_relu(x) = x if x > threshold, else 0
  __device__ __forceinline__ T operator()(const T& x) const {
    const T t = static_cast<T>(threshold);
    const bool passes = x > t;
    return passes ? x : zero;
  }
};

template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // The gradient passes dout through where x > threshold, else zero.
  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    const T t = static_cast<T>(threshold);
    const bool passes = x > t;
    return passes ? dout : zero;
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale, i.e.
  //   0                      when x <= -offset
  //   threshold * x / scale  when x >= threshold - offset
  //   x * (x + offset) / scale otherwise
  // With the defaults threshold = scale = 6, offset = 3 the saturated branch
  // reduces to x.
  __device__ __forceinline__ T operator()(const T& x) const {
    const T t = static_cast<T>(threshold);
    const T shifted = x + static_cast<T>(offset);
    T gate = shifted > zero ? shifted : zero;
    gate = gate < t ? gate : t;
    return gate * x / static_cast<T>(scale);
  }
};

template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  T two = static_cast<T>(2.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // Gradient of hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale:
  //   dx = 0,                               when x + offset <= 0
  //   dx = dout * threshold / scale,        when x + offset >= threshold
  //   dx = dout * (2 * x + offset) / scale, otherwise
  // With the defaults threshold = scale = 6, offset = 3, threshold / scale is
  // exactly 1, so the saturated branch is dout.
  //
  // 0/1 indicator values keep the expression branch-free. The two region
  // contributions are added directly rather than via `linear + 1 - indicator`,
  // because that form cancels catastrophically in low precision (e.g. float16
  // rounds 1 + small_slope back to 1, making a small gradient exactly 0).
  __device__ __forceinline__ T operator()(const T& dout, const T& x) const {
    T o = static_cast<T>(offset);
    T s = static_cast<T>(scale);
    T t = static_cast<T>(threshold);
    T above_low = static_cast<T>(x + o > zero);  // 1 when x + offset > 0
    T below_high = static_cast<T>(x + o < t);    // 1 when x + offset < threshold
    T linear = above_low * below_high * (two * x + o) / s;
    // The saturated region is threshold * x / scale, so its slope is
    // threshold / scale, not 1.
    T saturated = (one - below_high) * t / s;
    return dout * (linear + saturated);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1)), evaluated in CT (the
  // compute type given by details::MPTypeTrait<T>) and cast back to T.
  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    const CT scaled = static_cast<CT>(alpha) * (exp(x) - one);
    const CT positive_part = x > zero ? x : zero;
    const CT negative_part = scaled > zero ? zero : scaled;
    return static_cast<T>(positive_part + negative_part);
  }
};

template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Gradient of elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1)):
  //   alpha > 0,  x > 0  : dx = dout
  //   alpha > 0,  x <= 0 : dx = dout * alpha * exp(x)
  //   alpha <= 0, x > 0  : dx = dout * (1 + alpha * exp(x))
  //   alpha <= 0, x <= 0 : dx = 0
  // The active case is selected rather than summing every case multiplied by
  // a 0/1 mask. exp(x) overflows to inf for large x (about 88.7 when MPType is
  // float), and a masked-out term 0 * inf would turn the finite gradient
  // `dout` into NaN.
  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    if (alpha > 0.0f) {
      // alpha is the same for every element, so this branch does not diverge.
      return static_cast<T>(x > zero ? dout : dout * (a * exp(x)));
    }
    return static_cast<T>(x > zero ? dout * (one + a * exp(x)) : zero);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), evaluated in
  // CT and cast back to T. alpha is used as a divisor.
  __device__ __forceinline__ T operator()(const T& arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    const CT a = static_cast<CT>(alpha);
    const CT scaled = a * (exp(x / a) - one);
    const CT positive_part = x > zero ? x : zero;
    const CT negative_part = scaled > zero ? zero : scaled;
    return static_cast<T>(positive_part + negative_part);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Gradient of celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)).
  // For any non-zero alpha (positive or negative), the min(...) term is
  // active exactly when x <= 0, so the sign of alpha does not matter:
  //   x > 0  : dx = dout
  //   x <= 0 : dx = dout * exp(x / alpha)
  // A select is used instead of summing 0/1-masked terms. exp(x / alpha) can
  // overflow to inf when x > 0 is large relative to alpha, and a masked-out
  // 0 * inf would turn the finite gradient `dout` into NaN.
  __device__ __forceinline__ T operator()(const T& arg_dout,
                                          const T& arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType a = static_cast<MPType>(alpha);
    return static_cast<T>(x > zero ? dout : dout * exp(x / a));
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Forward kernel: allocates Out on the kernel's place and fills it by
  // applying Functor element-wise to X. Every attribute exposed through
  // Functor::GetAttrs() is read as a float from the op before the launch.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(ctx, &x, &out);
    out->mutable_data<T>(ctx.GetPlace());
    const auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto functor = Functor();
    for (auto& attr : functor.GetAttrs()) {
      *attr.second = ctx.Attr<float>(attr.first);
    }

    std::vector<const framework::Tensor*> ins{x};
    std::vector<framework::Tensor*> outs{out};
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
        dev_ctx, ins, &outs, functor);
  }
};

template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Backward kernel: writes dX from dOut plus, depending on
  // Functor::FwdDeps(), either the forward output Out or the forward input X.
  // Float attributes listed by Functor::GetAttrs() are read from the op first.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    const auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto functor = Functor();
    for (auto& attr : functor.GetAttrs()) {
      *attr.second = ctx.Attr<float>(attr.first);
    }

    std::vector<const framework::Tensor*> ins{d_out};
    std::vector<framework::Tensor*> outs{d_x};

    const int deps = static_cast<int>(Functor::FwdDeps());
    const bool needs_out = deps == static_cast<int>(kDepOut);
    const bool needs_x = deps == static_cast<int>(kDepX);
    if (needs_out || needs_x) {
      // Binary launch on (dOut, Out) or (dOut, X).
      ins.push_back(needs_out ? out : x);
      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
          dev_ctx, ins, &outs, functor);
    } else {
      // Any other dependency value: unary launch on dOut alone.
      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
          dev_ctx, ins, &outs, functor);
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Namespace shorthands used by every kernel registration below.
namespace plat = paddle::platform;
namespace ops = paddle::operators;

// REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, grad_functor)
// registers `act_type` with ActivationCudaKernel<functor<T>> and
// `act_type##_grad` with ActivationGradCudaKernel<grad_functor<T>>, for
// T in {float, double, plat::float16}. `op_name` is not used by the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,            \
                                        grad_functor)                          \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type,                                                                \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<float>>,                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<double>>,                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<plat::float16>>);                 \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type##_grad,                                                         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<float>>,                 \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<double>>,                \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<plat::float16>>);

// Same as REGISTER_ACTIVATION_CUDA_KERNEL, but also registers the forward and
// gradient kernels for int and int64_t. `op_name` is not used by the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor,        \
                                            grad_functor)                      \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type,                                                                \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<float>>,                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<double>>,                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<int>>,                            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<int64_t>>,                        \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<plat::float16>>);                 \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type##_grad,                                                         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<float>>,                 \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<double>>,                \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<int>>,                   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<int64_t>>,               \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<plat::float16>>);

/* ======================== leaky relu register  ============================ */
// Forward/backward via the shared macro; second-order gradient registered
// explicitly for float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
                                CudaLeakyReluGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== elu register  ============================ */
// Forward/backward via the shared macro; second-order gradient uses the
// dedicated ELU double-grad kernel.
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== celu register  ============================ */
// Forward/backward via the shared macro; second-order gradient uses the
// dedicated CELU double-grad kernel.
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
                                CudaCELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<float>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<double>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================    relu register  ============================ */
#ifdef PADDLE_WITH_HIP
// HIP builds: float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
                                CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
#else
// Non-HIP builds additionally register bfloat16 for relu and its gradients.
REGISTER_OP_CUDA_KERNEL(
    relu,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<plat::float16>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<plat::float16>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::bfloat16>>);
#endif
/* ========================================================================== */

/* ===========================    sigmoid register  ========================= */
// Forward/backward via the shared macro, plus second- and third-order
// gradient kernels for float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
                                CudaSigmoidGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    sigmoid_grad_grad,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<float>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<double>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    sigmoid_triple_grad,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<float>>,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<double>>,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================    tanh register  ============================ */
// Forward/backward via the shared macro, plus second- and third-order
// gradient kernels for float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, CudaTanhFunctor,
                                CudaTanhGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    tanh_grad_grad,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<float>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    tanh_triple_grad,
    ops::TanhTripeGradKernel<plat::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<float>>,
    ops::TanhTripeGradKernel<plat::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<double>>,
    ops::TanhTripeGradKernel<plat::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
// Forward/backward via the shared macro; second-order gradient registered
// for float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   rsqrt register  ============================ */
// Forward/backward via the shared macro; second-order gradient registered
// for float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================  square register  ============================ */
// Forward/backward via the integer-aware macro; second-order gradient
// registered for float, double, float16, int and int64_t.
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
                                    CudaSquareGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ==========================   pow register  ============================ */
// pow uses its own PowKernel/PowGradKernel wrappers; registered for float,
// double, int, int64_t and float16.
REGISTER_OP_CUDA_KERNEL(
    pow,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<plat::CUDADeviceContext,
                       ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   exp register  ============================ */
// Forward: CUDA element-wise functor for float, double and float16, but the
// generic ActivationKernel with ExpFunctor for int and int64_t.
REGISTER_OP_CUDA_KERNEL(
    exp,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<plat::float16>>);
// Backward: CUDA element-wise gradient functor for all five types.
REGISTER_OP_CUDA_KERNEL(
    exp_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   expm1 register  ============================ */
// Forward and backward for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    expm1,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    expm1_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================  Log register ==================================*/
// Forward/backward via the shared macro; second-order gradient registered
// for float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    log_grad_grad,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */

// X-macro over the activations whose CUDA kernels share one registration
// macro. Each entry is REGISTER_FN(op_type, OpName, Functor, GradFunctor);
// the line after the definition instantiates the list with
// REGISTER_ACTIVATION_CUDA_KERNEL. ceil, floor and round pair their forward
// functor with CudaZeroGradFunctor.
#define FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_FN)                             \
  REGISTER_FN(silu, Silu,                                                    \
              CudaSiluFunctor, CudaSiluGradFunctor);                         \
  REGISTER_FN(logsigmoid, LogSigmoid,                                        \
              CudaLogSigmoidFunctor, CudaLogSigmoidGradFunctor);             \
  REGISTER_FN(atan, Atan,                                                    \
              CudaAtanFunctor, CudaAtanGradFunctor);                         \
  REGISTER_FN(softshrink, SoftShrink,                                        \
              CudaSoftShrinkFunctor, CudaSoftShrinkGradFunctor);             \
  REGISTER_FN(ceil, Ceil,                                                    \
              CudaCeilFunctor, CudaZeroGradFunctor);                         \
  REGISTER_FN(floor, Floor,                                                  \
              CudaFloorFunctor, CudaZeroGradFunctor);                        \
  REGISTER_FN(cos, Cos,                                                      \
              CudaCosFunctor, CudaCosGradFunctor);                           \
  REGISTER_FN(tan, Tan,                                                      \
              CudaTanFunctor, CudaTanGradFunctor);                           \
  REGISTER_FN(acos, Acos,                                                    \
              CudaAcosFunctor, CudaAcosGradFunctor);                         \
  REGISTER_FN(sin, Sin,                                                      \
              CudaSinFunctor, CudaSinGradFunctor);                           \
  REGISTER_FN(asin, Asin,                                                    \
              CudaAsinFunctor, CudaAsinGradFunctor);                         \
  REGISTER_FN(sinh, Sinh,                                                    \
              CudaSinhFunctor, CudaSinhGradFunctor);                         \
  REGISTER_FN(cosh, Cosh,                                                    \
              CudaCoshFunctor, CudaCoshGradFunctor);                         \
  REGISTER_FN(round, Round,                                                  \
              CudaRoundFunctor, CudaZeroGradFunctor);                        \
  REGISTER_FN(reciprocal, Reciprocal,                                        \
              CudaReciprocalFunctor, CudaReciprocalGradFunctor);             \
  REGISTER_FN(log1p, Log1p,                                                  \
              CudaLog1pFunctor, CudaLog1pGradFunctor);                       \
  REGISTER_FN(log2, Log2,                                                    \
              CudaLog2Functor, CudaLog2GradFunctor);                         \
  REGISTER_FN(log10, Log10,                                                  \
              CudaLog10Functor, CudaLog10GradFunctor);                       \
  REGISTER_FN(brelu, BRelu,                                                  \
              CudaBReluFunctor, CudaBReluGradFunctor);                       \
  REGISTER_FN(soft_relu, SoftRelu,                                           \
              CudaSoftReluFunctor, CudaSoftReluGradFunctor);                 \
  REGISTER_FN(stanh, STanh,                                                  \
              CudaSTanhFunctor, CudaSTanhGradFunctor);                       \
  REGISTER_FN(softplus, Softplus,                                            \
              CudaSoftplusFunctor, CudaSoftplusGradFunctor);                 \
  REGISTER_FN(softsign, Softsign,                                            \
              CudaSoftsignFunctor, CudaSoftsignGradFunctor);                 \
  REGISTER_FN(relu6, Relu6,                                                  \
              CudaRelu6Functor, CudaRelu6GradFunctor);                       \
  REGISTER_FN(tanh_shrink, TanhShrink,                                       \
              CudaTanhShrinkFunctor, CudaTanhShrinkGradFunctor);             \
  REGISTER_FN(hard_shrink, HardShrink,                                       \
              CudaHardShrinkFunctor, CudaHardShrinkGradFunctor);             \
  REGISTER_FN(hard_sigmoid, HardSigmoid,                                     \
              CudaHardSigmoidFunctor, CudaHardSigmoidGradFunctor);           \
  REGISTER_FN(swish, Swish,                                                  \
              CudaSwishFunctor, CudaSwishGradFunctor);                       \
  REGISTER_FN(thresholded_relu, ThresholdedRelu,                             \
              CudaThresholdedReluFunctor, CudaThresholdedReluGradFunctor);   \
  REGISTER_FN(hard_swish, HardSwish,                                         \
              CudaHardSwishFunctor, CudaHardSwishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)