activation_op.kps 73.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
15
#include "paddle/fluid/platform/bfloat16.h"
16
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
17

18 19 20
namespace paddle {
namespace operators {

21 22 23 24 25
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // Forward ReLU: positive inputs pass through unchanged, everything else
  // (including values that do not compare greater than zero) maps to zero.
  // relu(x) = max(x, 0)
  __device__ __forceinline__ T operator()(const T val) const {
    if (val > zero) {
      return val;
    }
    return zero;
  }
};
30 31

template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // Backward ReLU driven by the forward output: the upstream gradient is
  // forwarded where out > 0 and zeroed everywhere else.
  // dx = dout * (out > 0)
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    if (y > zero) {
      return grad_out;
    }
    return zero;
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

43 44 45 46 47 48 49 50 51 52
template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Exposes the "alpha" attribute by address so it can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // leaky_relu(x) = x when x > 0, otherwise alpha * x.
  __device__ __forceinline__ T operator()(const T val) const {
    if (val > zero) {
      return val;
    }
    const T slope = static_cast<T>(alpha);
    return slope * val;
  }
};

58 59 60 61 62 63 64 65 66 67
template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Exposes the "alpha" attribute by address so it can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Local slope is 1 for x > 0 and alpha otherwise:
  // dx = dout * (x > 0 ? 1 : alpha)
  __device__ __forceinline__ T operator()(const T grad_out,
                                          const T in) const {
    if (in > zero) {
      return grad_out;
    }
    const T slope = static_cast<T>(alpha);
    return slope * grad_out;
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // sigmoid(x) = 1 / (1 + exp(-x)).
  // The input is promoted to MPType, evaluated there, and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType denom = one + exp(-v);
    return static_cast<T>(one / denom);
  }
};
86

87 88 89 90 91
template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Sigmoid derivative expressed through the forward output y = sigmoid(x):
  // dx = dout * y * (1 - y), computed directly in T.
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    return grad_out * y * (one - y);
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

99 100 101 102 103 104
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType denom = one + exp(-v);
    return static_cast<T>(v / denom);
  }
};
110 111

template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // With s = sigmoid(x) = 1 / (1 + exp(-x)):
  //   dx = dout * s * (1 + x * (1 - s))
  //      = dout * (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
  // Computed in MPType and rounded back to T once at the end.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType temp = one / (one + exp(-x));  // temp = sigmoid(x)
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
127

128 129 130 131 132 133 134 135 136
template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x))).
  // Evaluated in a log-sum-exp style to avoid overflow in exp(-x):
  // with m = max(-x, 0),
  //   logsigmoid(x) = -(m + log(exp(-m) + exp(-x - m)))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    MPType shift = -v;
    if (v > zero) {
      shift = zero;
    }
    return static_cast<T>(-shift - log(exp(-shift) + exp(-v - shift)));
  }
};
143 144

template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // dx = dout * exp(-x) / (1 + exp(-x)).
  // Rewritten with m = max(-x, 0) so no exponential overflows:
  //   dx = dout * exp(-x - m) / (exp(-m) + exp(-x - m))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    MPType shift = -v;
    if (v > zero) {
      shift = zero;
    }
    const MPType e = exp(-v - shift);
    return static_cast<T>(g * (e / (exp(-shift) + e)));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise arctangent, evaluated in MPType and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(atan(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout / (1 + x^2)
  // Evaluated in MPType (like the asin/acos gradients). In float16, x * x
  // overflows to inf once |x| exceeds ~256, which would flush a small but
  // still representable gradient to zero.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout / (one + x * x));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes the "lambda" attribute by address so it can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda, if x > lambda;
  //                 x + lambda, if x < -lambda;
  //                 0, otherwise.
  // Uses explicit branches instead of mask arithmetic
  // (mask1 * (x - l) + mask2 * (x + l)). The mask form evaluates both arms,
  // so an infinite or overflowing arm times a 0 mask gives 0 * inf = NaN.
  // For example, softshrink(+inf) would be NaN instead of +inf.
  __device__ __forceinline__ T operator()(const T x) const {
    T l = static_cast<T>(lambda);
    if (x > l) {
      return x - l;
    }
    if (x < -l) {
      return x + l;
    }
    // Multiplying by zero (rather than returning a literal 0) keeps NaN
    // inputs propagating as NaN.
    return x * static_cast<T>(0.0f);
  }
};

template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  // Exposes the "lambda" attribute by address so it can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // Gradient is zero inside the dead zone [-lambda, lambda] and passes
  // dout through everywhere else.
  __device__ __forceinline__ T operator()(const T grad_out,
                                          const T in) const {
    const T bound = static_cast<T>(lambda);
    const bool in_dead_zone = in >= -bound && in <= bound;
    if (in_dead_zone) {
      return zero;
    }
    return grad_out;
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Rounds each element toward +infinity, computed in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(ceil(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Rounds each element toward -infinity, computed in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(floor(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Rounds each element with the math library's round(), computed in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(round(static_cast<MPType>(arg_x)));
  }
};

258
// Gradient functor shared by ceil, floor and round.
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  // Always yields zero; the argument value is never inspected.
  __device__ __forceinline__ T operator()(const T /*x*/) const {
    return static_cast<T>(0.0f);
  }

  // Declares no forward-tensor dependency (kNoDeps).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};

template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise cosine, evaluated in MPType and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(cos(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d/dx cos(x) = -sin(x), so dx = -dout * sin(x), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(-g * sin(v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise sine, evaluated in MPType and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(sin(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d/dx sin(x) = cos(x), so dx = dout * cos(x), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g * cos(v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise tangent, evaluated in MPType and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(tan(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d/dx tan(x) = 1 / cos(x)^2, so dx = dout / (cos(x) * cos(x)), in MPType.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g / (cos(v) * cos(v)));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise arcsine, evaluated in MPType and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(asin(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // d/dx asin(x) = 1 / sqrt(1 - x^2), so dx = dout / sqrt(1 - x^2).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g / sqrt(one - v * v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise arccosine, evaluated in MPType and narrowed back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(acos(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // d/dx acos(x) = -1 / sqrt(1 - x^2), so dx = -dout / sqrt(1 - x^2).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(-g / sqrt(one - v * v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
399

400 401 402 403 404
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise hyperbolic cosine, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(cosh(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d/dx cosh(x) = sinh(x), so dx = dout * sinh(x), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g * sinh(v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise hyperbolic sine, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(sinh(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d/dx sinh(x) = cosh(x), so dx = dout * cosh(x), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g * cosh(v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise hyperbolic tangent, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(tanh(static_cast<MPType>(arg_x)));
  }
};
462

463 464 465 466 467
template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // tanh derivative expressed through the forward output y = tanh(x):
  // dx = dout * (1 - y^2), computed directly in T.
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    return grad_out * (one - y * y);
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

X
xiaoting 已提交
475 476 477 478 479
template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise inverse hyperbolic cosine, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(acosh(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // d/dx acosh(x) = 1 / sqrt(x^2 - 1), so dx = dout * 1 / sqrt(x^2 - 1).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g * one / sqrt(v * v - one));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise inverse hyperbolic sine, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(asinh(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // d/dx asinh(x) = 1 / sqrt(x^2 + 1), so dx = dout * 1 / sqrt(x^2 + 1).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g * one / sqrt(v * v + one));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise inverse hyperbolic tangent, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(atanh(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // d/dx atanh(x) = 1 / (1 - x^2), so dx = dout * 1 / (1 - x^2).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    return static_cast<T>(g * one / (one - v * v));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

554 555 556 557 558
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // reciprocal(x) = 1 / x, computed directly in T.
  __device__ __forceinline__ T operator()(const T val) const {
    const T inv = one / val;
    return inv;
  }
};
561

562
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // With y = 1 / x, dy/dx = -1 / x^2 = -y^2, so dx = -dout * y * y.
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    return -grad_out * y * y;
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
571

572 573 574 575 576
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise natural exponential, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(exp(static_cast<MPType>(arg_x)));
  }
};
582 583

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // exp is its own derivative, so dx = dout * y with y = exp(x).
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    return grad_out * y;
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
592

R
ronnywang 已提交
593 594 595 596 597
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // expm1(x) = exp(x) - 1, via the math library's expm1 in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(expm1(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  // expm1(x) = exp(x) - 1, so d/dx expm1(x) = exp(x) = out + 1 and
  // dx = dout * (out + 1), evaluated here as dout * out + dout.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    return dout * out + dout;
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

614 615 616 617 618
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise natural logarithm, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(log(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // d/dx log(x) = 1 / x, so dx = dout / x, computed directly in T.
  __device__ __forceinline__ T operator()(const T grad_out,
                                          const T in) const {
    return grad_out / in;
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // square(x) = x * x, computed directly in T.
  __device__ __forceinline__ T operator()(const T val) const {
    const T sq = val * val;
    return sq;
  }
};
640

641 642 643 644 645
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // d/dx x^2 = 2x, so dx = dout * 2 * x, computed directly in T.
  __device__ __forceinline__ T operator()(const T grad_out,
                                          const T in) const {
    return grad_out * two * in;
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

653 654 655 656 657
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise square root, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(sqrt(static_cast<MPType>(arg_x)));
  }
};
663

664 665 666 667 668
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // With y = sqrt(x), dy/dx = 1 / (2 * y), so dx = 0.5 * dout / y.
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    return one_half * grad_out / y;
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
675

676 677 678 679 680
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise reciprocal square root 1 / sqrt(x), via rsqrt in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(rsqrt(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // With y = x^(-1/2), dy/dx = -0.5 * x^(-3/2) = -0.5 * y^3,
  // so dx = -0.5 * dout * y * y * y.
  __device__ __forceinline__ T operator()(const T grad_out, const T y) const {
    return minus_one_half * grad_out * y * y * y;
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
698

699 700 701 702 703 704
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Retained so the functor's members stay unchanged; operator() no longer
  // needs it.
  MPType one = static_cast<MPType>(1.0f);

  // log1p(x) = log(1 + x)
  // Uses log1p instead of log(one + x). Forming 1 + x first discards the
  // low-order bits of a small |x|, and for sufficiently small |x| it rounds
  // to exactly 1, giving 0. log1p stays accurate near zero. The file already
  // relies on the analogous expm1 for the same reason.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log1p(x));
  }
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // d/dx log(1 + x) = 1 / (1 + x), so dx = dout / (1 + x), computed in T.
  __device__ __forceinline__ T operator()(const T grad_out,
                                          const T in) const {
    return grad_out / (one + in);
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise base-2 logarithm, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(log2(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // ln(2), kept in MPType so it is not rounded to T's precision.
  MPType log_two = log(static_cast<MPType>(2.0f));

  // dx = dout / (x * log(2))
  // Evaluated in MPType, like the other MPType-based gradients in this file.
  // float16/bfloat16 inputs then incur a single rounding back to T instead
  // of one rounding per intermediate step.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout / (x * log_two));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Element-wise base-10 logarithm, evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(log10(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // ln(10), kept in MPType so it is not rounded to T's precision.
  MPType log_ten = log(static_cast<MPType>(10.0f));

  // dx = dout / (x * log(10))
  // Evaluated in MPType. In float16, x * log(10) overflows to inf once |x|
  // exceeds about 28,400, which flushes a still-representable gradient to
  // zero.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(dout / (x * log_ten));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // Exposes the "t_min"/"t_max" attributes by address so they can be filled.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // Bounded ReLU: clamps x into [t_min, t_max].
  // brelu(x) = min(max(x, t_min), t_max)
  __device__ __forceinline__ T operator()(const T val) const {
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    T clipped = lo;
    if (val > lo) {
      clipped = val;
    }
    if (!(clipped < hi)) {
      clipped = hi;
    }
    return clipped;
  }
};

template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float t_min;
  float t_max;

  // Exposes the "t_min"/"t_max" attributes by address so they can be filled.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // Gradient passes through strictly inside (t_min, t_max), zero elsewhere.
  __device__ __forceinline__ T operator()(const T grad_out,
                                          const T in) const {
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    if (in > lo && in < hi) {
      return grad_out;
    }
    return zero;
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Exposes the "threshold" attribute by address so it can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold))).
  // threshold is expected to be non-negative.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType bound = static_cast<MPType>(threshold);
    MPType clipped = bound;
    if (v < bound) {
      clipped = v;
    }
    if (!(clipped > -bound)) {
      clipped = -bound;
    }
    return static_cast<T>(log(one + exp(clipped)));
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Exposes the "threshold" attribute by address so it can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Inside the open interval (-threshold, threshold) the gradient is
  // dout * (1 - exp(-out)); outside it is zero.
  // threshold is expected to be non-negative.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType y = static_cast<MPType>(arg_out);
    const MPType bound = static_cast<MPType>(threshold);
    if (y > -bound && y < bound) {
      return static_cast<T>(g * (one - exp(-y)));
    }
    return static_cast<T>(0.0f);
  }

  // The gradient is evaluated from the forward output (kDepOut).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  // Exposes "scale_a"/"scale_b" by address so they can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // Scaled tanh: stanh(x) = b * tanh(a * x), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType in_scale = static_cast<MPType>(scale_a);
    const MPType out_scale = static_cast<MPType>(scale_b);
    return static_cast<T>(out_scale * tanh(in_scale * v));
  }
};

template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  // Exposes "scale_a"/"scale_b" by address so they can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // dx = dout * a * b * (1 - tanh(a * x)^2), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    const MPType in_scale = static_cast<MPType>(scale_a);
    const MPType out_scale = static_cast<MPType>(scale_b);
    const MPType th = tanh(in_scale * v);
    return static_cast<T>(g * in_scale * out_scale * (one - th * th));
  }

  // The gradient is evaluated from the forward input x (kDepX).
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  // Exposes "beta"/"threshold" by address so they can be filled in.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
  // Above the threshold the function is treated as linear. This avoids
  // overflow in exp(beta * x).
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    // Use the MPType copy of beta here as well as in the division below.
    MPType x_beta = x * b;
    return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
  }
};

template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // dx = dout,                             if beta * x > threshold
  //    = dout / (1 + exp(-beta * x)),      otherwise (dout * sigmoid(beta*x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType b = static_cast<MPType>(beta);
    const MPType bx = static_cast<MPType>(arg_x) * b;
    if (bx > static_cast<MPType>(threshold)) {
      // Linear region: the slope is exactly 1, so pass dout through as-is.
      return arg_dout;
    }
    const MPType dout = static_cast<MPType>(arg_dout);
    return static_cast<T>(dout / (one + exp(-bx)));
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + |x|), evaluated directly in T.
  __device__ __forceinline__ T operator()(const T x) const {
    const T denom = one + abs(x);
    return x / denom;
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // d/dx softsign(x) = 1 / (1 + |x|)^2, hence dx = dout / (1 + |x|)^2.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T base = one + abs(x);
    const T base_sq = base * base;
    return dout / base_sq;
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), threshold)
  // The upper cap comes from the `threshold` attribute, not a literal 6.
  __device__ __forceinline__ T operator()(const T x) const {
    const T cap = static_cast<T>(threshold);
    if (x <= zero) {
      return zero;
    }
    return x < cap ? x : cap;
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // The gradient flows only where the forward output lies strictly inside
  // (0, threshold):  dx = (0 < out < threshold) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T cap = static_cast<T>(threshold);
    const bool in_linear_region = out > zero && out < cap;
    return in_linear_region ? dout : zero;
  }

  // The backward pass reads the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x), computed in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType shrunk = v - tanh(v);
    return static_cast<T>(shrunk);
  }
};

template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // d/dx (x - tanh(x)) = 1 - (1 - tanh(x)^2) = tanh(x)^2, so
  //   dx = dout * tanh(x)^2
  // tanh is evaluated once and reused. The multiplication order
  // (dout * t) * t is kept.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType th = tanh(static_cast<MPType>(arg_x));
    const MPType scaled = static_cast<MPType>(arg_dout) * th;
    return static_cast<T>(scaled * th);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = 0 if -threshold < x < threshold, otherwise x.
  __device__ __forceinline__ T operator()(const T x) const {
    const T t = static_cast<T>(threshold);
    const bool inside_band = x > -t && x < t;
    return inside_band ? zero : x;
  }
};

template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = 0 inside the shrink band (-threshold, threshold), dout elsewhere.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T t = static_cast<T>(threshold);
    if (x > -t && x < t) {
      return zero;
    }
    return dout;
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = clip(x * slope + offset, 0, 1), i.e.
  //   0                   when x * slope + offset <= 0
  //   1                   when x * slope + offset >= 1
  //   x * slope + offset  otherwise
  // The lower clamp is applied first. A NaN linear term fails `> zero` and
  // therefore yields 0.
  __device__ __forceinline__ T operator()(const T x) const {
    const T linear = x * static_cast<T>(slope) + static_cast<T>(offset);
    const T lower_clamped = linear > zero ? linear : zero;
    return lower_clamped < one ? lower_clamped : one;
  }
};

template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // The gradient is `slope` in the unsaturated region (0 < out < 1) and 0
  // where the output is clamped:
  //   dx = (0 < out < 1) ? dout * slope : 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    if (out > zero && out < one) {
      return dout * static_cast<T>(slope);
    }
    return zero;
  }

  // The backward pass reads the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType b = static_cast<MPType>(beta);
    const MPType denom = one + exp(-b * v);
    return static_cast<T>(v / denom);
  }
};

template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // With s = sigmoid(b * x) = 1 / (1 + exp(-b * x)) and out = x * s:
  //   dx = dout * (s + b * x * s * (1 - s))
  //      = dout * (b * out + s * (1 - b * out))
  // Equivalently dx = dout * (1 + e + b * x * e) / (1 + e)^2, e = exp(-b*x).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType b = static_cast<MPType>(beta);
    const MPType sig = one / (one + exp(-b * x));
    const MPType b_out = b * (x * sig);
    const MPType local_grad = b_out + sig * (one - b_out);
    return static_cast<T>(static_cast<MPType>(arg_dout) * local_grad);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // mish(x) = x * tanh(softplus(x)), where
  //   softplus(x) = x                if x > threshold
  //               = log(1 + exp(x))  otherwise
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const bool linear = x > static_cast<MPType>(threshold);
    const MPType softplus_x = linear ? x : log(one + exp(x));
    return static_cast<T>(x * tanh(softplus_x));
  }
};

template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // With sp = softplus(x) (linear above `threshold`, as in the forward):
  //   dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * sp')
  // sp' is 1 in the linear region and sigmoid(x) = 1 - exp(-sp) otherwise.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const bool linear = x > static_cast<MPType>(threshold);
    const MPType sp = linear ? x : log(one + exp(x));
    const MPType sp_grad = linear ? one : one / (one + exp(-x));
    const MPType tsp = tanh(sp);
    const MPType local_grad = tsp + x * (one - tsp * tsp) * sp_grad;
    return static_cast<T>(static_cast<MPType>(arg_dout) * local_grad);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

1196 1197 1198 1199 1200 1201 1202 1203 1204 1205
template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // thresholded_relu(x) = x if x > threshold, otherwise 0.
  __device__ __forceinline__ T operator()(const T x) const {
    const T cutoff = static_cast<T>(threshold);
    if (x > cutoff) {
      return x;
    }
    return zero;
  }
};

template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout where x > threshold (the pass-through region), 0 elsewhere.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T cutoff = static_cast<T>(threshold);
    if (x > cutoff) {
      return dout;
    }
    return zero;
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = x * clip(x + offset, 0, threshold) / scale, i.e.
  //   0                          when x + offset <= 0
  //   x * threshold / scale      when x + offset >= threshold
  //   x * (x + offset) / scale   otherwise
  // When threshold == scale (e.g. 6 / 6 with offset 3) the saturated branch
  // reduces to x.
  __device__ __forceinline__ T operator()(const T x) const {
    const T upper = static_cast<T>(threshold);
    const T shifted = x + static_cast<T>(offset);
    T clipped = shifted > zero ? shifted : zero;
    clipped = clipped < upper ? clipped : upper;
    return clipped * x / static_cast<T>(scale);
  }
};

template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  T two = static_cast<T>(2.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // Branch-free gradient built from two 0/1 masks on s = x + offset:
  //   dx = 0                                when s <= 0
  //   dx = dout                             when s >= threshold
  //   dx = dout * (2 * x + offset) / scale  otherwise
  // NOTE(review): the forward's saturated branch is x * threshold / scale,
  // whose slope is threshold / scale. The slope 1 used here matches only
  // when threshold == scale. Confirm this is intended before relying on
  // other settings.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T o = static_cast<T>(offset);
    const T s = static_cast<T>(scale);
    const T shifted = x + o;
    const T above_low = static_cast<T>(shifted > zero);
    const T below_high = static_cast<T>(shifted < static_cast<T>(threshold));
    const T interior = above_low * below_high * (two * x + o) / s;
    return dout * (interior + one - below_high);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = x                     if x > 0
  //        = alpha * (exp(x) - 1)  if x <= 0
  // Computed in CT (details::MPTypeTrait<T>::Type) and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    if (x > zero) {
      return static_cast<T>(x);
    }
    return static_cast<T>(static_cast<CT>(alpha) * (exp(x) - one));
  }
};

template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // ELU gradient expressed through the forward output (used by
  // ELUGradCudaKernel when alpha > 0):
  //   dx = dout                  if out > 0
  //   dx = dout * (out + alpha)  if out <= 0
  // For x <= 0, out = alpha * (exp(x) - 1), so alpha * exp(x) = out + alpha.
  // The branch is selected with 0/1 masks rather than a conditional.
  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    const MPType out = static_cast<MPType>(arg_out);
    const MPType is_pos = static_cast<MPType>(out > zero);
    const MPType is_nonpos = static_cast<MPType>(out <= zero);
    const MPType slope =
        is_pos + is_nonpos * (out + static_cast<MPType>(alpha));
    return static_cast<T>(static_cast<MPType>(arg_dout) * slope);
  }

  // The backward pass reads the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // ELU gradient for the alpha <= 0 path of ELUGradCudaKernel:
  //   dx = dout                  if x > 0
  //   dx = dout * (out + alpha)  if x <= 0
  // With alpha < 0, out = alpha * (exp(x) - 1) is >= 0 for x <= 0, so the
  // sign of out cannot identify the branch. The sign of x is used instead.
  // Inputs: dout, out, x (in that order).
  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_out,
                                          const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType is_pos = static_cast<MPType>(x > zero);
    const MPType is_nonpos = static_cast<MPType>(x <= zero);
    const MPType out_plus_alpha =
        static_cast<MPType>(arg_out) + static_cast<MPType>(alpha);
    const MPType slope = is_pos + is_nonpos * out_plus_alpha;
    return static_cast<T>(static_cast<MPType>(arg_dout) * slope);
  }

  // Reports kDepX. The functor also consumes Out, which ELUGradCudaKernel
  // supplies explicitly.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

Z
zhupengyang 已提交
1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369
template <typename DeviceContext, typename T>
class ELUGradCudaKernel : public framework::OpKernel<T> {
 public:
  // Backward kernel for elu. The functor is chosen by the sign of `alpha`:
  //   * alpha > 0:  CudaELUGradFunctor with inputs {dOut, Out}. With a
  //     positive alpha the forward output is > 0 exactly when x > 0, so Out
  //     alone decides the branch.
  //   * alpha <= 0: CudaELUGradNegativeAlphaFunctor with inputs
  //     {dOut, Out, X}. For alpha < 0 the x <= 0 branch
  //     alpha * (exp(x) - 1) is non-negative, so X is needed to pick the
  //     branch.
  // Marked `override` like the other kernels in this file, so a signature
  // drift from OpKernel::Compute becomes a compile error.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out = ctx.Input<framework::Tensor>("Out");
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    d_x->mutable_data<T>(ctx.GetPlace());
    const float alpha = ctx.Attr<float>("alpha");

    auto& dev_ctx = ctx.device_context<DeviceContext>();
    std::vector<const framework::Tensor*> ins = {d_out, out};
    std::vector<framework::Tensor*> outs = {d_x};
    if (alpha > 0) {
      CudaELUGradFunctor<T> functor;
      functor.alpha = alpha;
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    } else {
      CudaELUGradNegativeAlphaFunctor<T> functor;
      functor.alpha = alpha;
      // The negative-alpha functor takes (dout, out, x).
      ins.push_back(x);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                                &outs, functor);
    }
  }
};

1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  // Both halves are computed and then summed, in CT.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    const CT a = static_cast<CT>(alpha);
    const CT exp_part = a * (exp(x / a) - one);
    const CT pos_part = x > zero ? x : zero;
    const CT neg_part = exp_part > zero ? zero : exp_part;
    return static_cast<T>(pos_part + neg_part);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  // For either sign of alpha, the min(...) term is non-zero exactly when
  // x <= 0, so
  //   dx = dout                   if x > 0
  //   dx = dout * exp(x / alpha)  if x <= 0
  // The exponential is evaluated only on the x <= 0 branch. Masking it with
  // 0/1 factors would produce 0 * inf = NaN once exp overflows, e.g. for a
  // large positive x with alpha > 0, and poison the gradient.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    if (x > zero) {
      return static_cast<T>(dout);
    }
    const MPType a = static_cast<MPType>(alpha);
    return static_cast<T>(dout * exp(x / a));
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

1435
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Generic forward activation kernel: Out = Functor(X), applied
  // element-wise through LaunchSameDimsElementwiseCudaKernel. Each entry
  // returned by Functor::GetAttrs() is filled from the float op attribute
  // of the same name before the launch.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(ctx, &x, &out);
    out->mutable_data<T>(ctx.GetPlace());

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    std::vector<const framework::Tensor*> ins{x};
    std::vector<framework::Tensor*> outs{out};

    auto functor = Functor();
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = ctx.Attr<float>(name_and_slot.first);
    }

    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  }
};

template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Generic backward activation kernel. The first kernel input is always
  // dOut. It is followed by the forward tensor the functor declares through
  // FwdDeps(): Out for kDepOut, X for kDepX, nothing otherwise. Every branch
  // launches the same LaunchSameDimsElementwiseCudaKernel<T> instantiation,
  // so a single call site follows the input selection.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto functor = Functor();
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = ctx.Attr<float>(name_and_slot.first);
    }

    std::vector<const framework::Tensor*> ins{d_out};
    std::vector<framework::Tensor*> outs{d_x};

    const int deps = static_cast<int>(Functor::FwdDeps());
    if (deps == static_cast<int>(kDepOut)) {
      ins.push_back(out);  // functor needs the forward output
    } else if (deps == static_cast<int>(kDepX)) {
      ins.push_back(x);  // functor needs the forward input
    }

    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  }
};

}  // namespace operators
}  // namespace paddle

1501
namespace ops = paddle::operators;
1502 1503
namespace plat = paddle::platform;

1504 1505
// Registers the forward kernel `act_type` (ActivationCudaKernel<functor>)
// and the backward kernel `act_type##_grad`
// (ActivationGradCudaKernel<grad_functor>) for float, double, float16 and
// bfloat16. `op_name` is not referenced by the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,        \
                                        grad_functor)                      \
  REGISTER_OP_CUDA_KERNEL(                                                 \
      act_type,                                                            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<float>>,                      \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<double>>,                     \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<plat::float16>>,              \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<plat::bfloat16>>);            \
  REGISTER_OP_CUDA_KERNEL(                                                 \
      act_type##_grad,                                                     \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<float>>,             \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<double>>,            \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<plat::float16>>,     \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<plat::bfloat16>>);
1525

1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537
// Like REGISTER_ACTIVATION_CUDA_KERNEL, but additionally instantiates the
// forward and backward kernels for int and int64_t. `op_name` is not
// referenced by the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor,    \
                                            grad_functor)                  \
  REGISTER_OP_CUDA_KERNEL(                                                 \
      act_type,                                                            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<float>>,                      \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<double>>,                     \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<int>>,                        \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<int64_t>>,                    \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<plat::float16>>,              \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                   \
                                ops::functor<plat::bfloat16>>);            \
  REGISTER_OP_CUDA_KERNEL(                                                 \
      act_type##_grad,                                                     \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<float>>,             \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<double>>,            \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<int>>,               \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<int64_t>>,           \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<plat::float16>>,     \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,               \
                                    ops::grad_functor<plat::bfloat16>>);
1555

1556
/* ======================== leaky relu register  ============================ */
// Forward + first-order backward for float, double, float16, bfloat16.
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
                                CudaLeakyReluGradFunctor);

// Second-order backward: float, double and float16 only.
REGISTER_OP_CUDA_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
1569

D
Double_V 已提交
1570
/* ======================== elu register  ============================ */
// elu registers its kernels explicitly (no bfloat16). Its first-order
// backward uses the dedicated ELUGradCudaKernel, which picks a functor from
// the sign of alpha.
REGISTER_OP_CUDA_KERNEL(
    elu,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaELUFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaELUFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaELUFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    elu_grad,
    ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
    ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
    ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604
/* ======================== celu register  ============================ */
// Forward + first-order backward through the common activation macro.
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
                                CudaCELUGradFunctor);

// Second-order backward: float, double and float16 only.
REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<float>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<double>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

1605
/* ===========================    relu register  ============================ */
#ifdef PADDLE_WITH_HIP
// HIP build: forward/backward through the common macro. The double-grad
// kernel has no bfloat16 instantiation.
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
                                CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
#else
// Non-HIP build: relu, relu_grad and relu_grad_grad are each registered for
// float, double, float16 and bfloat16.
REGISTER_OP_CUDA_KERNEL(
    relu,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<plat::float16>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaReluFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<plat::float16>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaReluGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::bfloat16>>);
#endif
/* ========================================================================== */

1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660
/* ==========================   sigmoid register  ========================== */
// sigmoid is registered through the shared helper macro with
// CudaSigmoidFunctor / CudaSigmoidGradFunctor.
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
                                CudaSigmoidGradFunctor);

// sigmoid_grad_grad kernels: float, double, float16 and bfloat16.
REGISTER_OP_CUDA_KERNEL(
    sigmoid_grad_grad,
    ops::SigmoidDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<float>>,
    ops::SigmoidDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<double>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::float16>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::bfloat16>>);

// sigmoid_triple_grad kernels: float, double, float16 and bfloat16.
REGISTER_OP_CUDA_KERNEL(
    sigmoid_triple_grad,
    ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<float>>,
    ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<double>>,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<plat::float16>>,
    ops::SigmoidTripleGradKernel<
        plat::CUDADeviceContext,
        ops::SigmoidTripleGradFunctor<plat::bfloat16>>);
1676 1677
/* ========================================================================== */

1678
/* ===========================    tanh register  ============================ */
1679 1680
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, CudaTanhFunctor,
                                CudaTanhGradFunctor);
1681 1682 1683 1684 1685 1686 1687 1688 1689

REGISTER_OP_CUDA_KERNEL(
    tanh_grad_grad,
    ops::TanhDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::TanhGradGradFunctor<float>>,
    ops::TanhDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);
1690 1691 1692 1693 1694 1695 1696 1697 1698

REGISTER_OP_CUDA_KERNEL(
    tanh_triple_grad,
    ops::TanhTripeGradKernel<paddle::platform::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<float>>,
    ops::TanhTripeGradKernel<paddle::platform::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<double>>,
    ops::TanhTripeGradKernel<plat::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<plat::float16>>);
1699 1700
/* ========================================================================== */

L
lvmengsi 已提交
1701
/* ===========================   sqrt register  ============================= */
1702 1703
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);
L
lvmengsi 已提交
1704 1705 1706 1707 1708 1709 1710 1711

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
1712 1713 1714
                              ops::SqrtGradGradFunctor<plat::float16>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::bfloat16>>);
L
lvmengsi 已提交
1715 1716
/* ========================================================================== */

W
whs 已提交
1717 1718
/* ===========================   rsqrt register  =============================
 */
1719 1720
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);
W
whs 已提交
1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

1732
/* ===========================  square register  ============================ */
1733 1734
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
                                    CudaSquareGradFunctor);
1735 1736 1737 1738 1739 1740 1741 1742

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
1743
                                ops::SquareGradGradFunctor<plat::float16>>,
1744 1745
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::bfloat16>>,
1746 1747 1748 1749
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
1750
/* ========================================================================== */
1751 1752 1753 1754 1755

/* ==========================   pow register  ============================ */
// pow / pow_grad kernels: float, double, int, int64_t and float16.
REGISTER_OP_CUDA_KERNEL(
    pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<plat::CUDADeviceContext,
                       ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
1768

W
wangzhen38 已提交
1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783
/* ==========================   logit register  ============================ */
// logit / logit_grad kernels: float, double and float16. Unlike the
// functor-based activations above, these kernels are templated directly on
// the element type. The `ops` namespace alias is already in scope here (it is
// used at file scope by the registrations above), so it is not redeclared.
REGISTER_OP_CUDA_KERNEL(
    logit, ops::LogitKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LogitKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LogitKernel<paddle::platform::CUDADeviceContext,
                     paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    logit_grad,
    ops::LogitGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LogitGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::LogitGradKernel<paddle::platform::CUDADeviceContext,
                         paddle::platform::float16>);
/* ========================================================================== */

1784 1785
/* ==========================   exp register  ============================ */
// exp forward kernels. float, double and float16 use the CUDA functor path
// (ActivationCudaKernel + CudaExpFunctor); int and int64_t use the generic
// ActivationKernel + ExpFunctor path instead.
REGISTER_OP_CUDA_KERNEL(
    exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                   ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<plat::float16>>);

// exp_grad kernels: every type (float, double, int, int64_t, float16) goes
// through ActivationGradCudaKernel + CudaExpGradFunctor.
REGISTER_OP_CUDA_KERNEL(
    exp_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                            ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<plat::float16>>);
1805 1806
/* ========================================================================== */

R
ronnywang 已提交
1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824
/* ==========================   expm1 register  ============================ */
// expm1 forward kernels, one per element type (float, double, float16),
// all built from ActivationCudaKernel + CudaExpm1Functor.
REGISTER_OP_CUDA_KERNEL(
    expm1,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<plat::float16>>);

// expm1_grad kernels for the same three element types, built from
// ActivationGradCudaKernel + CudaExpm1GradFunctor.
REGISTER_OP_CUDA_KERNEL(
    expm1_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

1825
/* ==========================  Log register ==================================*/
1826
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);
1827 1828 1829 1830 1831 1832 1833 1834 1835

REGISTER_OP_CUDA_KERNEL(
    log_grad_grad, ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                                            ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
1836

1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852
// X-macro list of activations whose CUDA kernels are registered purely
// through REGISTER_ACTIVATION_CUDA_KERNEL. Each entry is
// (op type, op name, forward functor, gradient functor); the list is expanded
// once, immediately after the definition. No comments may appear inside the
// backslash-continued body, since a `//` comment would swallow the line splice.
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                  \
  __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor);                  \
  __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,                      \
          CudaLogSigmoidGradFunctor);                                         \
  __macro(atan, Atan, CudaAtanFunctor, CudaAtanGradFunctor);                  \
  __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                      \
          CudaSoftShrinkGradFunctor);                                         \
  __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor);                  \
  __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor);               \
  __macro(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);                      \
  __macro(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);                      \
  __macro(acos, Acos, CudaAcosFunctor, CudaAcosGradFunctor);                  \
  __macro(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);                      \
  __macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor);                  \
  __macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor);                  \
  __macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor);                  \
  __macro(asinh, Asinh, CudaAsinhFunctor, CudaAsinhGradFunctor);              \
  __macro(acosh, Acosh, CudaAcoshFunctor, CudaAcoshGradFunctor);              \
  __macro(atanh, Atanh, CudaAtanhFunctor, CudaAtanhGradFunctor);              \
  __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor);               \
  __macro(reciprocal, Reciprocal, CudaReciprocalFunctor,                      \
          CudaReciprocalGradFunctor);                                         \
  __macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor);              \
  __macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor);                  \
  __macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor);              \
  __macro(brelu, BRelu, CudaBReluFunctor, CudaBReluGradFunctor);              \
  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
  __macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor);              \
  __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor);  \
  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);  \
  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);              \
  __macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor,                     \
          CudaTanhShrinkGradFunctor);                                         \
  __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor,                     \
          CudaHardShrinkGradFunctor);                                         \
  __macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor,                  \
          CudaHardSigmoidGradFunctor);                                        \
  __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor);              \
  __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor);                  \
  __macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor,      \
          CudaThresholdedReluGradFunctor);                                    \
  __macro(hard_swish, HardSwish, CudaHardSwishFunctor,                        \
          CudaHardSwishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944

#ifdef PADDLE_WITH_XPU_KP
// For the KP library on plat::XPUPlace, registers a float kernel for
// `act_type` (ActivationCudaKernel + functor) and a float kernel for
// `act_type##_grad` (ActivationGradCudaKernel + grad_functor), both on
// plat::XPUDeviceContext. `op_name` is not referenced by the expansion.
#define REGISTER_ACTIVATION_XPU_KERNEL(act_type, op_name, functor,             \
                                       grad_functor)                           \
  REGISTER_OP_KERNEL(                                                          \
      act_type, KP, plat::XPUPlace,                                            \
      ops::ActivationCudaKernel<plat::XPUDeviceContext, ops::functor<float>>); \
  REGISTER_OP_KERNEL(act_type##_grad, KP, plat::XPUPlace,                      \
                     ops::ActivationGradCudaKernel<plat::XPUDeviceContext,     \
                                                   ops::grad_functor<float>>);

// Activations registered for XPU KP, listed in registration order as
// (op type, op name, forward functor, gradient functor). The list is expanded
// once below and then undefined so the helper name does not outlive this
// section.
#define FOR_EACH_ACTIVATION_XPU_KP_OP(__macro)                                 \
  __macro(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,                         \
          CudaLeakyReluGradFunctor);                                           \
  __macro(relu, Relu, CudaReluFunctor, CudaReluGradFunctor);                   \
  __macro(sigmoid, Sigmoid, CudaSigmoidFunctor, CudaSigmoidGradFunctor);       \
  __macro(exp, Exp, CudaExpFunctor, CudaExpGradFunctor);                       \
  __macro(log, Log, CudaLogFunctor, CudaLogGradFunctor);                       \
  __macro(reciprocal, Reciprocal, CudaReciprocalFunctor,                       \
          CudaReciprocalGradFunctor);                                          \
  __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor);   \
  __macro(hard_swish, HardSwish, CudaHardSwishFunctor,                         \
          CudaHardSwishGradFunctor);                                           \
  __macro(elu, Elu, CudaELUFunctor, CudaELUGradFunctor);                       \
  __macro(celu, Celu, CudaCELUFunctor, CudaCELUGradFunctor);                   \
  __macro(sqrt, Sqrt, CudaSqrtFunctor, CudaSqrtGradFunctor);                   \
  __macro(square, Square, CudaSquareFunctor, CudaSquareGradFunctor);           \
  __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor);                   \
  __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,                       \
          CudaLogSigmoidGradFunctor);                                          \
  __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                       \
          CudaSoftShrinkGradFunctor);                                          \
  __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor);                   \
  __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor);                \
  __macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor);               \
  __macro(brelu, BRelu, CudaBReluFunctor, CudaBReluGradFunctor);               \
  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor);  \
  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);   \
  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);               \
  __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor,                      \
          CudaHardShrinkGradFunctor);                                          \
  __macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor,                   \
          CudaHardSigmoidGradFunctor);                                         \
  __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor);               \
  __macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor,       \
          CudaThresholdedReluGradFunctor);

FOR_EACH_ACTIVATION_XPU_KP_OP(REGISTER_ACTIVATION_XPU_KERNEL)
#undef FOR_EACH_ACTIVATION_XPU_KP_OP

#endif  // PADDLE_WITH_XPU_KP