activation_op.cu 61.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
15
#include "paddle/fluid/operators/math/math_cuda_utils.h"
16
#include "paddle/fluid/platform/bfloat16.h"
17
#include "paddle/fluid/platform/cuda_device_function.h"
18

19 20 21
namespace paddle {
namespace operators {

template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // Forward ReLU: out = max(x, 0).
  // args[0] is the input element x. An x that does not compare greater than
  // zero (this includes NaN) produces zero.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    if (x > zero) {
      return x;
    }
    return zero;
  }
};

template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // Backward ReLU: dx = dout where the forward output is positive, else 0.
  // args[0] is dout (upstream gradient), args[1] is the forward output.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T out = args[1];
    if (out > zero) {
      return dout;
    }
    return zero;
  }

  // Only the forward output is required to form the gradient.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Maps the attribute name "alpha" to this functor's `alpha` field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Forward leaky ReLU: out = x when x > 0, otherwise alpha * x.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    if (x > zero) {
      return x;
    }
    const T slope = static_cast<T>(alpha);
    return slope * x;
  }
};

template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;

  // Maps the attribute name "alpha" to this functor's `alpha` field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Backward leaky ReLU: dx = dout when x > 0, otherwise alpha * dout.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    if (x > zero) {
      return dout;
    }
    return static_cast<T>(alpha) * dout;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Forward sigmoid: out = 1 / (1 + exp(-x)).
  // args[0] is x; the math runs in MPType and the result is cast to T.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    const MPType denom = one + exp(-x);
    return static_cast<T>(one / denom);
  }
};

template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Backward sigmoid: dx = dout * out * (1 - out), computed directly in T.
  // args[0] is dout, args[1] is the forward output.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T out = args[1];
    return dout * out * (one - out);
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  // Compute type: the math below runs in MPType and is cast back to T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Forward SiLU: out = x / (1 + exp(-x)).
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    const MPType denom = one + exp(-x);
    return static_cast<T>(x / denom);
  }
};

template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Backward SiLU. With s = sigmoid(x) = 1 / (1 + exp(-x)) and
  // silu(x) = x * s, the derivative is s + x * s * (1 - s), hence
  //   dx = dout * s * (1 + x * (1 - s))
  // Inputs: args[0], the input dout
  //         args[1], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    // temp = sigmoid(x), evaluated in MPType.
    MPType temp = one / (one + exp(-x));
    return static_cast<T>(dout * (temp * (one + x * (one - temp))));
  }

  // Only the forward input x is needed to form the gradient.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // Forward log-sigmoid: out = log(1 / (1 + exp(-x))).
  // Evaluated as -(m + log(exp(-m) + exp(-x - m))) with m = max(-x, 0), so
  // both exponents are <= 0 and exp() cannot overflow.
  // args[0] is x; the math runs in MPType.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    MPType shift = -x;
    if (x > zero) {
      shift = zero;
    }
    const MPType total = exp(-shift) + exp(-x - shift);
    return static_cast<T>(-shift - log(total));
  }
};

template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // Backward log-sigmoid: dx = dout * exp(-x) / (1 + exp(-x)).
  // Numerator and denominator are both scaled by exp(-m), m = max(-x, 0):
  //   dx = dout * exp(-x - m) / (exp(-m) + exp(-x - m))
  // which keeps every exponent <= 0.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType dout = static_cast<MPType>(args[0]);
    const MPType x = static_cast<MPType>(args[1]);
    MPType shift = -x;
    if (x > zero) {
      shift = zero;
    }
    const MPType numer = exp(-x - shift);
    const MPType denom = exp(-shift) + numer;
    return static_cast<T>(dout * (numer / denom));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = atan(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(atan(value));
  }
};

template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Backward: dx = dout / (1 + x^2), computed in T.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    const T denom = one + x * x;
    return dout / denom;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Maps the attribute name "lambda" to this functor's `lambda` field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // Forward soft-shrink:
  //   out = x - lambda   if x >  lambda
  //   out = x + lambda   if x < -lambda
  //   out = 0            otherwise
  // Evaluated branch-free: each piece is weighted by a 0/1 indicator of T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    const T l = static_cast<T>(lambda);
    const T above = static_cast<T>(x > l);
    const T below = static_cast<T>(x < -l);
    return above * (x - l) + below * (x + l);
  }
};

template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  // Maps the attribute name "lambda" to this functor's `lambda` field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // Backward soft-shrink: dx = 0 for -lambda <= x <= lambda, else dout.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[1];
    const T l = static_cast<T>(lambda);
    const bool in_dead_zone = (x >= -l && x <= l);
    if (in_dead_zone) {
      return zero;
    }
    return args[0];
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = ceil(x), rounded in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(ceil(value));
  }
};

template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = floor(x), rounded in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(floor(value));
  }
};

template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = round(x), rounded in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(round(value));
  }
};

// Gradient functor shared by ceil, floor and round: it always yields zero.
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  // `args` is not read; every output element is 0.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T result = static_cast<T>(0.0f);
    return result;
  }

  // No forward tensors are needed to produce this gradient.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};

template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = cos(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType angle = static_cast<MPType>(args[0]);
    return static_cast<T>(cos(angle));
  }
};

template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Backward: dx = -dout * sin(x), evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType angle = static_cast<MPType>(args[1]);
    const MPType neg_grad = -grad_out;
    return static_cast<T>(neg_grad * sin(angle));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = sin(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType angle = static_cast<MPType>(args[0]);
    return static_cast<T>(sin(angle));
  }
};

template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Backward: dx = dout * cos(x), evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType angle = static_cast<MPType>(args[1]);
    return static_cast<T>(grad_out * cos(angle));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = tan(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType angle = static_cast<MPType>(args[0]);
    return static_cast<T>(tan(angle));
  }
};

template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Backward: dx = dout / cos(x)^2, evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType angle = static_cast<MPType>(args[1]);
    const MPType c = cos(angle);
    return static_cast<T>(grad_out / (c * c));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = asin(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(asin(value));
  }
};

template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Backward: dx = dout / sqrt(1 - x^2), evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    const MPType radicand = one - value * value;
    return static_cast<T>(grad_out / sqrt(radicand));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = acos(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(acos(value));
  }
};

template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Backward: dx = -dout / sqrt(1 - x^2), evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    const MPType radicand = one - value * value;
    return static_cast<T>(-grad_out / sqrt(radicand));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = cosh(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(cosh(value));
  }
};

template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Backward: dx = dout * sinh(x), evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    return static_cast<T>(grad_out * sinh(value));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = sinh(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(sinh(value));
  }
};

template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Backward: dx = dout * cosh(x), evaluated in MPType.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType grad_out = static_cast<MPType>(args[0]);
    const MPType value = static_cast<MPType>(args[1]);
    return static_cast<T>(grad_out * cosh(value));
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = tanh(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(tanh(value));
  }
};

template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Backward tanh: dx = dout * (1 - out^2), computed in T.
  // args[0] is dout, args[1] is the forward output.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T grad_out = args[0];
    const T y = args[1];
    const T slope = one - y * y;
    return grad_out * slope;
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Forward: out = 1 / x, computed in T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    return one / x;
  }
};

template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // Backward reciprocal: dx = -dout * out^2, computed in T.
  // args[0] is dout, args[1] is the forward output.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T out = args[1];
    return -dout * out * out;
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = exp(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(exp(value));
  }
};

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // Backward exp: d/dx exp(x) = exp(x) = out, so dx = dout * out (in T).
  // args[0] is dout, args[1] is the forward output.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T out = args[1];
    return dout * out;
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = expm1(x) = exp(x) - 1, evaluated in MPType with the
  // dedicated expm1() and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(expm1(value));
  }
};

template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  // Backward expm1. Since expm1(x) = exp(x) - 1, its derivative is
  // exp(x) = out + 1, so
  //   dx = dout * (out + 1) = dout * out + dout
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] * args[1] + args[0];
  }

  // The gradient is expressed through the forward output only.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = log(x) (natural log), evaluated in MPType and cast to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(log(value));
  }
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // Backward log: dx = dout / x, computed in T.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    return dout / x;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // Forward: out = x * x, computed in T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    return x * x;
  }
};

template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // Backward square: dx = dout * 2 * x, computed in T.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    return dout * two * x;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = sqrt(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(sqrt(value));
  }
};

template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // Backward sqrt: d/dx sqrt(x) = 0.5 / sqrt(x) = 0.5 / out, so
  // dx = 0.5 * dout / out, computed in T.
  // args[0] is dout, args[1] is the forward output.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T out = args[1];
    const T scaled = one_half * dout;
    return scaled / out;
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = rsqrt(x) = 1 / sqrt(x), evaluated in MPType and cast to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(rsqrt(value));
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // Backward rsqrt. rsqrt(x) = x^(-1/2), so its derivative is
  // -0.5 * x^(-3/2) = -0.5 * out^3, giving
  //   dx = -0.5 * dout * out^3
  // Inputs: args[0], the input dout
  //         args[1], the input out
  __device__ __forceinline__ T operator()(const T* args) const {
    T out = args[1];
    return minus_one_half * args[0] * out * out * out;
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Retained as a public member for compatibility; not used by operator().
  MPType one = static_cast<MPType>(1.0f);

  // log1p(x) = log(1 + x), evaluated in MPType and cast back to T.
  // Uses the dedicated log1p() (the counterpart of the expm1() used by
  // CudaExpm1Functor) instead of log(1 + x): forming 1 + x first rounds away
  // the low-order bits of a small |x|, which log1p() preserves.
  // Inputs: args[0], the input x
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType x = static_cast<MPType>(args[0]);
    return static_cast<T>(log1p(x));
  }
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Backward log1p: dx = dout / (1 + x), computed in T.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    const T denom = one + x;
    return dout / denom;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = log2(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(log2(value));
  }
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // ln(2): computed in MPType when the functor is constructed, stored as T.
  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

  // Backward log2: dx = dout / (x * ln 2), computed in T.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    const T denom = x * log_two;
    return dout / denom;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Forward: out = log10(x), evaluated in MPType and cast back to T.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType value = static_cast<MPType>(args[0]);
    return static_cast<T>(log10(value));
  }
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // ln(10): computed in MPType when the functor is constructed, stored as T.
  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

  // Backward log10: dx = dout / (x * ln 10), computed in T.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    const T denom = x * log_ten;
    return dout / denom;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaBReluFunctor : public BaseActivationFunctor<T> {
  float t_min;
  float t_max;

  // Maps the attribute names "t_min"/"t_max" to this functor's fields.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // Forward bounded ReLU: out = min(max(x, t_min), t_max).
  // The lower bound is applied first, then the upper bound; an x that does
  // not compare greater than t_min (including NaN) is replaced by t_min.
  // args[0] is the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T x = args[0];
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    T clipped = lo;
    if (x > lo) {
      clipped = x;
    }
    if (clipped < hi) {
      return clipped;
    }
    return hi;
  }
};

template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float t_min;
  float t_max;

  // Maps the attribute names "t_min"/"t_max" to this functor's fields.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // Backward bounded ReLU: dx = dout strictly inside (t_min, t_max), else 0.
  // args[0] is dout, args[1] is the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T dout = args[0];
    const T x = args[1];
    const T lo = static_cast<T>(t_min);
    const T hi = static_cast<T>(t_max);
    const bool passes = x > lo && x < hi;
    if (passes) {
      return dout;
    }
    return zero;
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Maps the attribute name "threshold" to this functor's field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Forward soft ReLU:
  //   out = log(1 + exp(clip(x, -threshold, threshold)))
  // The upper bound is applied before the lower bound. threshold is expected
  // to be non-negative. args[0] is x; the math runs in MPType.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    const MPType t = static_cast<MPType>(threshold);
    MPType upper_clipped = t;
    if (x < t) {
      upper_clipped = x;
    }
    MPType clipped = -t;
    if (upper_clipped > -t) {
      clipped = upper_clipped;
    }
    return static_cast<T>(log(one + exp(clipped)));
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Maps the attribute name "threshold" to this functor's field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Backward soft ReLU:
  //   dx = dout * (1 - exp(-out))  when -threshold < out < threshold
  //   dx = 0                       otherwise
  // threshold is expected to be non-negative.
  // args[0] is dout, args[1] is the forward output; the math runs in MPType.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType dout = static_cast<MPType>(args[0]);
    const MPType out = static_cast<MPType>(args[1]);
    const MPType t = static_cast<MPType>(threshold);
    const bool in_range = out > -t && out < t;
    if (!in_range) {
      return static_cast<T>(0.0f);
    }
    return static_cast<T>(dout * (one - exp(-out)));
  }

  // Only the forward output is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  // Maps the attribute names "scale_a"/"scale_b" to this functor's fields.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // Forward scaled tanh: out = scale_b * tanh(scale_a * x).
  // args[0] is x; the math runs in MPType.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    const MPType a = static_cast<MPType>(scale_a);
    const MPType b = static_cast<MPType>(scale_b);
    const MPType inner = a * x;
    return static_cast<T>(b * tanh(inner));
  }
};

template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  // Maps the attribute names "scale_a"/"scale_b" to this functor's fields.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // Backward scaled tanh:
  //   dx = dout * a * b * (1 - tanh(a * x)^2)
  // args[0] is dout, args[1] is the forward input x; the math runs in MPType.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType dout = static_cast<MPType>(args[0]);
    const MPType x = static_cast<MPType>(args[1]);
    const MPType a = static_cast<MPType>(scale_a);
    const MPType b = static_cast<MPType>(scale_b);
    const MPType th = tanh(a * x);
    const MPType sech_sq = one - th * th;
    return static_cast<T>(dout * a * b * sech_sq);
  }

  // The gradient depends on the forward input x.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  // Maps the attribute names "beta"/"threshold" to this functor's fields.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // Forward softplus with a linear cut-over:
  //   out = x                               if beta * x > threshold
  //   out = log(1 + exp(beta * x)) / beta   otherwise
  // args[0] is x; the math runs in MPType.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType x = static_cast<MPType>(args[0]);
    const MPType b = static_cast<MPType>(beta);
    const MPType t = static_cast<MPType>(threshold);
    const MPType x_beta = x * b;
    if (x_beta > t) {
      return static_cast<T>(x);
    }
    return static_cast<T>(log(one + exp(x_beta)) / b);
  }
};

template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // Gradient of softplus with respect to x:
  //   dx = dout,                          if beta * x > threshold
  //   dx = dout / (1 + exp(-beta * x)),   otherwise
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType scaled =
        static_cast<MPType>(args[1]) * static_cast<MPType>(beta);
    if (scaled > static_cast<MPType>(threshold)) {
      // Linear region: the gradient passes through unchanged.
      return args[0];
    }
    const MPType grad_out = static_cast<MPType>(args[0]);
    return static_cast<T>(grad_out / (one + exp(-scaled)));
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + |x|), evaluated directly in T.
  // args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[0];
    const T denom = one + abs(in);
    return in / denom;
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Gradient of softsign with respect to x:
  //   dx = dout / (1 + |x|)^2
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T grad_out = args[0];
    const T denom = one + abs(args[1]);
    return grad_out / (denom * denom);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), threshold), where `threshold` is the op
  // attribute (the "6" in the name).
  // args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[0];
    const T cap = static_cast<T>(threshold);
    if (in <= zero) {
      return zero;
    }
    return in < cap ? in : cap;
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Gradient of relu6: dout flows through only where the forward output lies
  // strictly between 0 and threshold; elsewhere it is 0.
  // args[0] holds dout, args[1] holds the forward output out.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T fwd_out = args[1];
    const T cap = static_cast<T>(threshold);
    const bool passes = fwd_out > zero && fwd_out < cap;
    return passes ? args[0] : zero;
  }

  // The backward pass reads the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x), computed in MPType and cast back to T.
  // args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType in = static_cast<MPType>(args[0]);
    const MPType th = tanh(in);
    return static_cast<T>(in - th);
  }
};

template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Gradient of tanhshrink(x) = x - tanh(x):
  //   d/dx = 1 - sech^2(x) = tanh^2(x), so dx = dout * tanh(x)^2
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    // Evaluate tanh once instead of twice; the product is unchanged.
    MPType th = tanh(x);
    return static_cast<T>(dout * th * th);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = 0 if -threshold < x < threshold, otherwise x.
  // args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[0];
    const T bound = static_cast<T>(threshold);
    const bool inside = in > -bound && in < bound;
    return inside ? zero : in;
  }
};

template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Gradient of hardshrink: 0 inside the open interval (-threshold,
  // threshold), dout outside it.
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[1];
    const T bound = static_cast<T>(threshold);
    const bool inside = in > -bound && in < bound;
    return inside ? zero : args[0];
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = clamp(x * slope + offset, 0, 1), i.e.
  //   0,                   when x * slope + offset <= 0
  //   1,                   when x * slope + offset >= 1
  //   x * slope + offset,  otherwise
  // The breakpoints therefore depend on the attributes; for example
  // slope = 1/6, offset = 0.5 puts them at x = -3 and x = 3.
  // args[0] holds the input element x; arithmetic is done directly in T.
  __device__ __forceinline__ T operator()(const T* args) const {
    T temp = args[0] * static_cast<T>(slope) + static_cast<T>(offset);
    T temp_max = temp > zero ? temp : zero;
    T temp_min = temp_max < one ? temp_max : one;
    return temp_min;
  }
};

template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // Gradient of hard_sigmoid: dout * slope where the forward output is in the
  // linear region (0 < out < 1), 0 where it was clamped.
  // args[0] holds dout, args[1] holds the forward output out.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T fwd_out = args[1];
    const bool linear = fwd_out > zero && fwd_out < one;
    return linear ? args[0] * static_cast<T>(slope) : zero;
  }

  // The backward pass reads the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x / (1 + exp(-beta * x)), i.e. x * sigmoid(beta * x),
  // computed in MPType and cast back to T.
  // args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const MPType in = static_cast<MPType>(args[0]);
    const MPType neg_scaled = -static_cast<MPType>(beta) * in;
    return static_cast<T>(in / (one + exp(neg_scaled)));
  }
};

template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // Gradient of swish(x) = x * s, with s = sigmoid(b * x) = 1 / (1 + e),
  // e = exp(-b * x):
  //   dx = dout * (s + b * x * s * (1 - s))
  //      = dout * (1 + e + b * x * e) / (1 + e)^2
  // The code below evaluates the first form as
  //   temp2 + temp3 = b*x*s + s*(1 - b*x*s).
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    MPType b = static_cast<MPType>(beta);
    MPType temp1 = one / (one + exp(-b * x));  // s
    MPType out = x * temp1;                     // forward swish(x)
    MPType temp2 = b * out;                     // b * x * s
    MPType temp3 = temp1 * (one - temp2);       // s * (1 - b * x * s)
    return static_cast<T>(dout * (temp2 + temp3));
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // thresholded_relu(x) = x when x is strictly above threshold, else 0.
  // args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[0];
    const T cutoff = static_cast<T>(threshold);
    return in > cutoff ? in : zero;
  }
};

template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Gradient of thresholded_relu: dout where x > threshold, else 0.
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T cutoff = static_cast<T>(threshold);
    const bool active = args[1] > cutoff;
    return active ? args[0] : zero;
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = x * clamp(x + offset, 0, threshold) / scale, i.e.
  //   0,                         when x <= -offset
  //   x * threshold / scale,     when x >= threshold - offset
  //   x * (x + offset) / scale,  otherwise
  // (threshold = scale = 6 and offset = 3 by default, so the middle case
  // reduces to x.) args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[0];
    const T upper = static_cast<T>(threshold);
    T clamped = in + static_cast<T>(offset);
    clamped = clamped > zero ? clamped : zero;    // clamp below at 0
    clamped = clamped < upper ? clamped : upper;  // clamp above at threshold
    return clamped * in / static_cast<T>(scale);
  }
};

template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  T two = static_cast<T>(2.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // Gradient of hard_swish (threshold = scale = 6, offset = 3 by default):
  //   dx = 0,                                when x + offset <= 0
  //   dx = dout,                             when x + offset >= threshold
  //   dx = dout * (2 * x + offset) / scale,  otherwise
  // The three cases are encoded with 0/1 masks in a single expression.
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const T in = args[1];
    const T off = static_cast<T>(offset);
    const T sc = static_cast<T>(scale);
    const T shifted = in + off;
    const T above_floor = static_cast<T>(shifted > zero);
    const T below_ceiling =
        static_cast<T>(shifted < static_cast<T>(threshold));
    return args[0] *
           (above_floor * below_ceiling * (two * in + off) / sc + one -
            below_ceiling);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1)), computed in CT and
  // cast back to T. args[0] holds the input element x.
  __device__ __forceinline__ T operator()(const T* args) const {
    const CT in = static_cast<CT>(args[0]);
    const CT expo = static_cast<CT>(alpha) * (exp(in) - one);
    const CT pos_part = in > zero ? in : zero;
    const CT neg_part = expo > zero ? zero : expo;
    return static_cast<T>(pos_part + neg_part);
  }
};

template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Gradient of elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1)):
  //   dx = dout,                         if alpha > 0 and x > 0
  //   dx = dout * alpha * exp(x),        if alpha > 0 and x <= 0
  //   dx = dout * (1 + alpha * exp(x)),  if alpha <= 0 and x > 0
  //   dx = 0,                            if alpha <= 0 and x <= 0
  // The cases are selected with branches rather than by multiplying 0/1
  // masks: with masks, an unselected term such as 0 * alpha * exp(x) turns
  // into 0 * inf = NaN as soon as exp(x) overflows (large positive x), which
  // made dx NaN even in the plain pass-through case.
  // args[0] holds dout, args[1] holds the forward input x.
  __device__ __forceinline__ T operator()(const T* args) const {
    MPType dout = static_cast<MPType>(args[0]);
    MPType x = static_cast<MPType>(args[1]);
    MPType a = static_cast<MPType>(alpha);
    if (alpha > 0.0f) {
      return x > zero ? args[0] : static_cast<T>(dout * (a * exp(x)));
    }
    return x > zero ? static_cast<T>(dout * (one + a * exp(x)))
                    : static_cast<T>(zero);
  }

  // The backward pass reads the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Generic forward CUDA kernel for element-wise activations.
//
// Fetches X and Out via ExtractActivationTensor and allocates Out. It then
// copies every float attribute the functor lists in GetAttrs() from the op
// into the functor. Finally it launches a unary same-dims element-wise kernel
// that applies `Functor` to each element of X.
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(ctx, &x, &out);
    out->mutable_data<T>(ctx.GetPlace());

    // Populate the functor's attribute fields from the op attributes.
    auto functor = Functor();
    for (auto& attr : functor.GetAttrs()) {
      *attr.second = ctx.Attr<float>(attr.first);
    }

    std::vector<const framework::Tensor*> ins{x};
    std::vector<framework::Tensor*> outs{out};
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
        dev_ctx, ins, &outs, functor);
  }
};

// Generic backward CUDA kernel for element-wise activations.
//
// dOut is always the first kernel input. What happens next depends on
// Functor::FwdDeps():
//   - kDepOut: the forward output Out is appended as the second input and a
//     binary kernel is launched.
//   - kDepX: the forward input X is appended instead, again with a binary
//     kernel.
//   - anything else: a unary kernel runs on dOut alone.
// Float attributes listed by Functor::GetAttrs() are copied from the op
// before the launch.
template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());

    // Populate the functor's attribute fields from the op attributes.
    auto functor = Functor();
    for (auto& attr : functor.GetAttrs()) {
      *attr.second = ctx.Attr<float>(attr.first);
    }

    const int deps = static_cast<int>(Functor::FwdDeps());
    const bool use_out = deps == static_cast<int>(kDepOut);
    const bool use_x = deps == static_cast<int>(kDepX);

    std::vector<const framework::Tensor*> ins{d_out};
    std::vector<framework::Tensor*> outs{d_x};
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    if (use_out || use_x) {
      // Binary functor: args[0] = dOut, args[1] = Out or X.
      ins.push_back(use_out ? out : x);
      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
          dev_ctx, ins, &outs, functor);
    } else {
      // No forward tensor needed: args[0] = dOut only.
      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
          dev_ctx, ins, &outs, functor);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the CUDA forward kernel `act_type` and its backward kernel
// `act_type##_grad`, each instantiated for float, double and plat::float16,
// through the generic ActivationCudaKernel / ActivationGradCudaKernel
// wrappers. `op_name` is not referenced in the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,            \
                                        grad_functor)                          \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type,                                                                \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<float>>,                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<double>>,                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<plat::float16>>);                 \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type##_grad,                                                         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<float>>,                 \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<double>>,                \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<plat::float16>>);

// Like REGISTER_ACTIVATION_CUDA_KERNEL, but the forward and backward kernels
// are additionally instantiated for int and int64_t. Per-op order is float,
// double, int, int64_t, plat::float16. `op_name` is not referenced.
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor,        \
                                            grad_functor)                      \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type,                                                                \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<float>>,                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<double>>,                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<int>>,                            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<int64_t>>,                        \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                       \
                                ops::functor<plat::float16>>);                 \
  REGISTER_OP_CUDA_KERNEL(                                                     \
      act_type##_grad,                                                         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<float>>,                 \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<double>>,                \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<int>>,                   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<int64_t>>,               \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                   \
                                    ops::grad_functor<plat::float16>>);

/* ======================== leaky relu register  ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
                                CudaLeakyReluGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== elu register  ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad,
    ops::ELUDoubleGradKernel<
        plat::CUDADeviceContext, ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<
        plat::CUDADeviceContext, ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<
        plat::CUDADeviceContext, ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================    relu register  ============================ */
// The HIP build registers relu through the generic macro (float, double,
// float16). The non-HIP build spells the registrations out and additionally
// registers plat::bfloat16 for relu, relu_grad and relu_grad_grad.
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
                                CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<plat::float16>>);
#else
REGISTER_OP_CUDA_KERNEL(
    relu,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluFunctor<float>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluFunctor<double>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluFunctor<plat::float16>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu_grad,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluGradFunctor<float>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluGradFunctor<double>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluGradFunctor<plat::float16>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaReluGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<plat::float16>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::ReluGradGradFunctor<plat::bfloat16>>);
#endif
/* ========================================================================== */

/* ===========================    sigmoid register  ========================= */
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
                                CudaSigmoidGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    sigmoid_grad_grad,
    ops::SigmoidDoubleGradKernel<
        plat::CUDADeviceContext, ops::SigmoidGradGradFunctor<float>>,
    ops::SigmoidDoubleGradKernel<
        plat::CUDADeviceContext, ops::SigmoidGradGradFunctor<double>>,
    ops::SigmoidDoubleGradKernel<
        plat::CUDADeviceContext, ops::SigmoidGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================    tanh register  ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, CudaTanhFunctor,
                                CudaTanhGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    tanh_grad_grad,
    ops::TanhDoubleGradKernel<
        plat::CUDADeviceContext, ops::TanhGradGradFunctor<float>>,
    ops::TanhDoubleGradKernel<
        plat::CUDADeviceContext, ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<
        plat::CUDADeviceContext, ops::TanhGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<
        plat::CUDADeviceContext, ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<
        plat::CUDADeviceContext, ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<
        plat::CUDADeviceContext, ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   rsqrt register  ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<
        plat::CUDADeviceContext, ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<
        plat::CUDADeviceContext, ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<
        plat::CUDADeviceContext, ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================  square register  ============================ */
// square (forward, backward and double-grad) is registered for int and
// int64_t in addition to the floating-point types.
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
                                    CudaSquareGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<
        plat::CUDADeviceContext, ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<
        plat::CUDADeviceContext, ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<
        plat::CUDADeviceContext, ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<
        plat::CUDADeviceContext, ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<
        plat::CUDADeviceContext, ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ==========================   pow register  ============================ */
// pow / pow_grad go through PowKernel / PowGradKernel rather than the generic
// activation wrappers; int and int64_t are registered as well.
REGISTER_OP_CUDA_KERNEL(
    pow,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<
        plat::CUDADeviceContext, ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   exp register  ============================ */
// Forward: floating-point types use ActivationCudaKernel + CudaExpFunctor,
// while int / int64_t use ActivationKernel + ExpFunctor.
// Backward: every type uses ActivationGradCudaKernel + CudaExpGradFunctor.
REGISTER_OP_CUDA_KERNEL(
    exp,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    exp_grad,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   expm1 register  ============================ */
// expm1 is registered for float, double and float16 only.
REGISTER_OP_CUDA_KERNEL(
    expm1,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1Functor<float>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1Functor<double>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1Functor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    expm1_grad,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1GradFunctor<float>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1GradFunctor<double>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================  Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);

// Second-order gradient: float, double, float16.
REGISTER_OP_CUDA_KERNEL(
    log_grad_grad,
    ops::LogDoubleGradKernel<
        plat::CUDADeviceContext, ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<
        plat::CUDADeviceContext, ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<
        plat::CUDADeviceContext, ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */

// X-macro listing the activations whose CUDA kernels are registered through a
// single uniform per-op macro.
//
// Each entry has the form
//   act_macro(op_type, OpName, Functor, GradFunctor);
// The macro passed as the argument is expanded once per entry, in list order.
// The trailing `;` after each entry is supplied by this list. The ceil, floor
// and round entries use CudaZeroGradFunctor as their gradient functor.
//
// The parameter is deliberately not spelled `__macro`: any identifier that
// contains a double underscore is reserved to the implementation in C++.
// Keep comments outside the #define. A `//` comment inside a
// backslash-continued macro would swallow the following spliced line.
#define FOR_EACH_ACTIVATION_CUDA_OP(act_macro)                                 \
  act_macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor);                 \
  act_macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,                     \
            CudaLogSigmoidGradFunctor);                                        \
  act_macro(atan, Atan, CudaAtanFunctor, CudaAtanGradFunctor);                 \
  act_macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                     \
            CudaSoftShrinkGradFunctor);                                        \
  act_macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor);                 \
  act_macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor);              \
  act_macro(cos, Cos, CudaCosFunctor, CudaCosGradFunctor);                     \
  act_macro(tan, Tan, CudaTanFunctor, CudaTanGradFunctor);                     \
  act_macro(acos, Acos, CudaAcosFunctor, CudaAcosGradFunctor);                 \
  act_macro(sin, Sin, CudaSinFunctor, CudaSinGradFunctor);                     \
  act_macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor);                 \
  act_macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor);                 \
  act_macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor);                 \
  act_macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor);              \
  act_macro(reciprocal, Reciprocal, CudaReciprocalFunctor,                     \
            CudaReciprocalGradFunctor);                                        \
  act_macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor);             \
  act_macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor);                 \
  act_macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor);             \
  act_macro(brelu, BRelu, CudaBReluFunctor, CudaBReluGradFunctor);             \
  act_macro(soft_relu, SoftRelu, CudaSoftReluFunctor,                          \
            CudaSoftReluGradFunctor);                                          \
  act_macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor);             \
  act_macro(softplus, Softplus, CudaSoftplusFunctor,                           \
            CudaSoftplusGradFunctor);                                          \
  act_macro(softsign, Softsign, CudaSoftsignFunctor,                           \
            CudaSoftsignGradFunctor);                                          \
  act_macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);             \
  act_macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor,                    \
            CudaTanhShrinkGradFunctor);                                        \
  act_macro(hard_shrink, HardShrink, CudaHardShrinkFunctor,                    \
            CudaHardShrinkGradFunctor);                                        \
  act_macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor,                 \
            CudaHardSigmoidGradFunctor);                                       \
  act_macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor);             \
  act_macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor,     \
            CudaThresholdedReluGradFunctor);                                   \
  act_macro(hard_swish, HardSwish, CudaHardSwishFunctor,                       \
            CudaHardSwishGradFunctor);
// Apply REGISTER_ACTIVATION_CUDA_KERNEL (defined elsewhere) to every
// (op_type, OpName, Functor, GradFunctor) entry of FOR_EACH_ACTIVATION_CUDA_OP,
// so all listed activations share one registration path.
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)