/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
15
#include "paddle/fluid/platform/bfloat16.h"
16
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
17

18 19
#include "paddle/phi/kernels/funcs/activation_functor.h"

20 21 22
namespace paddle {
namespace operators {

23 24 25 26 27
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns ceil(arg_x). The input is converted to MPType for the ceil()
  // call and the result is converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(ceil(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns floor(arg_x). The input is converted to MPType for the floor()
  // call and the result is converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(floor(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns round(arg_x). The input is converted to MPType for the round()
  // call and the result is converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(round(value));
  }
};

56
// Gradient functor shared by ceil, floor and round.
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  // The argument is ignored; the result is always zero converted to T.
  __device__ __forceinline__ T operator()(const T x) const {
    const T zero_grad = static_cast<T>(0.0f);
    return zero_grad;
  }

  // The backward pass needs neither the forward input nor the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kNoDeps;
  }
};

template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // reciprocal(x) = 1 / x, evaluated directly in T.
  __device__ __forceinline__ T operator()(const T x) const {
    const T inverse = one / x;
    return inverse;
  }
};
75

76
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // dx = -dout * out^2, where out is the forward result 1 / x.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T negated_dout = -dout;
    return negated_dout * out * out;
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
87

88 89 90 91 92
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns exp(arg_x), evaluated in MPType and converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(exp(value));
  }
};
98 99

template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out, because d/dx exp(x) equals the forward output itself.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T grad = dout * out;
    return grad;
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
110

R
ronnywang 已提交
111 112 113 114 115
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns expm1(arg_x), evaluated in MPType and converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(expm1(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out + dout, i.e. dout * (out + 1): with out = exp(x) - 1 the
  // derivative exp(x) is out + 1.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    // Kept as a single expression so the arithmetic matches the original
    // form exactly.
    return dout * out + dout;
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

134 135 136
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  // square(x) = x * x, evaluated directly in T.
  __device__ __forceinline__ T operator()(const T x) const {
    const T squared = x * x;
    return squared;
  }
};
139

140 141 142 143 144
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // dx = dout * 2 * x, evaluated left to right in T.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T doubled_dout = dout * two;
    return doubled_dout * x;
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

152 153 154 155 156
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns sqrt(arg_x), evaluated in MPType and converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(sqrt(static_cast<MPType>(arg_x)));
  }
};
162

163 164 165 166 167
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // dx = 0.5 * dout / out, where out is the forward result sqrt(x).
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T half_dout = one_half * dout;
    return half_dout / out;
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
176

177 178 179 180 181
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Returns rsqrt(arg_x), evaluated in MPType and converted back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(rsqrt(value));
  }
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // dx = -0.5 * dout * out^3, where out is the forward result rsqrt(x).
  // The product is accumulated left to right in T.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T scaled_dout = minus_one_half * dout;
    return scaled_dout * out * out * out;
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
201

202 203 204 205 206 207 208 209 210 211 212 213
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold)))
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType limit = static_cast<MPType>(threshold);
    MPType clipped = static_cast<MPType>(arg_x);
    // Upper clip: keep the value only when it is strictly below the limit.
    if (!(clipped < limit)) {
      clipped = limit;
    }
    // Lower clip: keep the value only when it is strictly above -limit.
    if (!(clipped > -limit)) {
      clipped = -limit;
    }
    return static_cast<T>(log(one + exp(clipped)));
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout * (1 - exp(-out)) when -threshold < out < threshold, else 0
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    const MPType out = static_cast<MPType>(arg_out);
    const MPType limit = static_cast<MPType>(threshold);
    const bool inside_band = out > -limit && out < limit;
    if (!inside_band) {
      return static_cast<T>(0.0f);
    }
    const MPType dout = static_cast<MPType>(arg_dout);
    return static_cast<T>(dout * (one - exp(-out)));
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  float scale_a;
  float scale_b;

  // Maps each attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // stanh(x) = scale_b * tanh(scale_a * x), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType inner_scale = static_cast<MPType>(scale_a);
    const MPType outer_scale = static_cast<MPType>(scale_b);
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<T>(outer_scale * tanh(inner_scale * value));
  }
};

template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  // Maps each attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // dx = dout * a * b * (1 - tanh(a * x)^2), with a = scale_a, b = scale_b.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType a = static_cast<MPType>(scale_a);
    const MPType b = static_cast<MPType>(scale_b);
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType tanh_ax = tanh(a * static_cast<MPType>(arg_x));
    return static_cast<T>(dout * a * b * (one - tanh_ax * tanh_ax));
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  // Maps each attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // softplus(x) = beta * x > threshold ? x : log(1 + exp(beta * x)) / beta
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType b = static_cast<MPType>(beta);
    const MPType scaled = value * b;
    // Above the threshold the function is treated as the identity.
    if (scaled > static_cast<MPType>(threshold)) {
      return static_cast<T>(value);
    }
    return static_cast<T>(log(one + exp(scaled)) / b);
  }
};

template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  // Maps each attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType b = static_cast<MPType>(beta);
    const MPType t = static_cast<MPType>(threshold);
    // Use the converted beta so the whole computation stays in MPType; the
    // converted value was previously computed but never used.
    const MPType x_beta = x * b;
    // Above the threshold the forward pass is the identity, so the incoming
    // gradient is forwarded untouched.
    return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + |x|), evaluated directly in T.
  __device__ __forceinline__ T operator()(const T x) const {
    const T denominator = one + abs(x);
    return x / denominator;
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + |x|)^2, evaluated directly in T.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T base = one + abs(x);
    const T base_squared = base * base;
    return dout / base_squared;
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Returns 0 when x <= 0, x when 0 < x < threshold, and threshold otherwise.
  // The upper bound comes from the `threshold` attribute rather than a
  // hard-coded 6.
  __device__ __forceinline__ T operator()(const T x) const {
    if (x <= zero) {
      return zero;
    }
    const T upper = static_cast<T>(threshold);
    return x < upper ? x : upper;
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout when 0 < out < threshold, otherwise 0.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T upper = static_cast<T>(threshold);
    const bool in_linear_region = out > zero && out < upper;
    if (in_linear_region) {
      return dout;
    }
    return zero;
  }

  // The backward pass reads the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x / (1 + exp(-beta * x)), evaluated in MPType.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType slope = static_cast<MPType>(beta);
    return static_cast<T>(value / (one + exp(-slope * value)));
  }
};

template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // With s = 1 / (1 + exp(-b * x)) and out = x * s:
  //   dx = dout * (b * out + s * (1 - b * out))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType b = static_cast<MPType>(beta);
    const MPType sigmoid_bx = one / (one + exp(-b * x));
    // The forward result is recomputed from x; only x is read from the
    // forward pass.
    const MPType forward_out = x * sigmoid_bx;
    const MPType scaled_out = b * forward_out;
    const MPType sigmoid_term = sigmoid_bx * (one - scaled_out);
    return static_cast<T>(dout * (scaled_out + sigmoid_term));
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

442 443 444 445 446 447 448 449 450 451 452 453 454 455
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // mish(x) = x * tanh(softplus(x))
  // softplus(x) = x               if x > threshold
  //             = ln(1 + exp(x))  otherwise
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    MPType softplus;
    if (value > static_cast<MPType>(threshold)) {
      softplus = value;
    } else {
      softplus = log(one + exp(value));
    }
    return static_cast<T>(value * tanh(softplus));
  }
};

template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * gsp)
  // sp  = softplus(x): x if x > threshold, else ln(1 + exp(x))
  // gsp = d(sp)/dx:    1 if x > threshold, else 1 / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const bool past_threshold = x > static_cast<MPType>(threshold);
    const MPType softplus = past_threshold ? x : log(one + exp(x));
    const MPType softplus_grad = past_threshold ? one : one / (one + exp(-x));
    const MPType tanh_sp = tanh(softplus);
    return static_cast<T>(
        dout * (tanh_sp + x * (one - tanh_sp * tanh_sp) * softplus_grad));
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

491 492 493 494 495 496 497 498 499 500 501 502 503 504 505
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;
  float scale;
  float offset;

  // Maps each attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = clip(x + offset, 0, threshold) * x / scale, i.e.
  //   0                         when x <= -offset
  //   threshold * x / scale     when x >= threshold - offset
  //   x * (x + offset) / scale  otherwise
  // threshold = scale = 6, offset = 3 by default
  __device__ __forceinline__ T operator()(const T x) const {
    const T upper = static_cast<T>(threshold);
    T gate = x + static_cast<T>(offset);
    // Clamp from below: keep the value only when it is strictly positive.
    if (!(gate > zero)) {
      gate = zero;
    }
    // Clamp from above: keep the value only when it is strictly below upper.
    if (!(gate < upper)) {
      gate = upper;
    }
    return gate * x / static_cast<T>(scale);
  }
};

template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  T two = static_cast<T>(2.0f);
  float threshold;
  float scale;
  float offset;

  // Maps each attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // dx = 0, when x <= -offset
  //      dout, when x >= threshold - offset
  //      dout * (2 * x / scale + offset / scale), otherwise
  // threshold = scale = 6, offset = 3 by default
  //
  // The three cases are selected arithmetically with two 0/1 masks instead of
  // branches.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T shift = static_cast<T>(offset);
    const T divisor = static_cast<T>(scale);
    const T shifted = x + shift;
    const T above_zero = static_cast<T>(shifted > zero);
    const T below_cap = static_cast<T>(shifted < static_cast<T>(threshold));
    // NOTE(review): if two * x + shift overflows T while a mask is 0, the
    // product 0 * inf would be NaN -- confirm inputs cannot get that large.
    return dout * (above_zero * below_cap * (two * x + shift) / divisor + one -
                   below_cap);
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

543 544 545 546 547 548 549 550 551 552 553 554
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT value = static_cast<CT>(arg_x);
    const CT exp_branch =
        static_cast<CT>(alpha) * (exp(value / static_cast<CT>(alpha)) - one);
    const CT positive_part = value > zero ? value : zero;
    const CT negative_part = exp_branch > zero ? zero : exp_branch;
    return static_cast<T>(positive_part + negative_part);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  // Compute type selected by details::MPTypeTrait for element type T.
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  // Maps the attribute name to the member that stores its value.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout,                   if x > 0
  // dx = dout * exp(x / alpha),  if x <= 0
  // (the same two cases apply for positive and negative alpha)
  //
  // The cases are selected with a branch rather than by multiplying with 0/1
  // masks: with masks, a large positive x / alpha makes exp() overflow to
  // +inf and 0 * inf turns the gradient into NaN instead of dout.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    if (x > zero) {
      // Identity region of the forward pass: the gradient passes through.
      return static_cast<T>(dout);
    }
    const MPType a = static_cast<MPType>(alpha);
    return static_cast<T>(dout * exp(x / a));
  }

  // The backward pass reads the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

596
// Forward kernel for element-wise activations. `Functor` supplies the element
// type (ELEMENT_TYPE), the float attributes to load (GetAttrs) and the
// per-element computation.
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Extracts the input/output tensors from `ctx`, allocates the output,
  // loads the functor's attributes from `ctx` and launches the element-wise
  // kernel.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* input = nullptr;
    framework::Tensor* output = nullptr;
    ExtractActivationTensor(ctx, &input, &output);
    output->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    // Copy every attribute the functor declares into its member storage.
    Functor functor = Functor();
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = ctx.Attr<float>(name_and_slot.first);
    }

    std::vector<const framework::Tensor*> ins = {input};
    std::vector<framework::Tensor*> outs = {output};
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  }
};

// Backward kernel for element-wise activations. `Functor` supplies the element
// type (ELEMENT_TYPE), the float attributes to load (GetAttrs), the forward
// tensor it depends on (FwdDeps) and the per-element gradient computation.
template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Extracts the tensors from `ctx`, allocates dX, loads the functor's
  // attributes and launches the element-wise kernel. The first kernel input
  // is always dOut; Out or X is appended according to Functor::FwdDeps().
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    // Copy every attribute the functor declares into its member storage.
    Functor functor = Functor();
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = ctx.Attr<float>(name_and_slot.first);
    }

    std::vector<const framework::Tensor*> ins = {d_out};
    std::vector<framework::Tensor*> outs = {d_x};

    const int fwd_deps = static_cast<int>(Functor::FwdDeps());
    if (fwd_deps == static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      // Only need forward output Out
      ins.push_back(out);
    } else if (fwd_deps == static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      // Only need forward input X
      ins.push_back(x);
    }
    // Any other dependency value launches with dOut as the sole input.
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  }
};

660 661 662 663 664 665 666 667 668 669 670 671 672 673 674
// Functors made available through USE_PHI_FUNCTOR. The macro is defined
// outside this file; presumably it aliases the phi::funcs functor and its
// grad functor into this namespace -- TODO confirm against its definition.

// Trigonometric and hyperbolic functions.
USE_PHI_FUNCTOR(CudaCos)
USE_PHI_FUNCTOR(CudaTan)
USE_PHI_FUNCTOR(CudaAcos)
USE_PHI_FUNCTOR(CudaSin)
USE_PHI_FUNCTOR(CudaAsin)
USE_PHI_FUNCTOR(CudaAtan)
USE_PHI_FUNCTOR(CudaSinh)
USE_PHI_FUNCTOR(CudaCosh)
USE_PHI_FUNCTOR(CudaAsinh)
USE_PHI_FUNCTOR(CudaAcosh)
USE_PHI_FUNCTOR(CudaAtanh)
USE_PHI_FUNCTOR(CudaTanh)

// ReLU variants, shrink functions, SiLU and ELU.
USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)

// Sigmoid family.
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)

// Logarithms.
USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
Y
YuanRisheng 已提交
687 688 689 690

// Alias that exposes phi::funcs::CudaELUGradNegativeAlphaFunctor under the
// same name in paddle::operators.
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
    phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;
691

692 693 694
}  // namespace operators
}  // namespace paddle

695
namespace ops = paddle::operators;
696 697
namespace plat = paddle::platform;

698 699
// Registers CUDA kernels for the op `act_type` and for `act_type##_grad`,
// instantiating `functor` / `grad_functor` for float, double, float16 and
// bfloat16. `op_name` is accepted but not referenced in the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,           \
                                        grad_functor)                         \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      act_type,                                                               \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<float>>,                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<double>>,                        \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<plat::float16>>,                 \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<plat::bfloat16>>);               \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      act_type##_grad,                                                        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<float>>,                \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<double>>,               \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<plat::float16>>,        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<plat::bfloat16>>);
719

720 721 722 723 724 725 726 727 728 729 730 731
// Same as REGISTER_ACTIVATION_CUDA_KERNEL but additionally instantiates the
// forward and backward kernels for int and int64_t. `op_name` is accepted but
// not referenced in the expansion.
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor,       \
                                            grad_functor)                     \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      act_type,                                                               \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<float>>,                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<double>>,                        \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<int>>,                           \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<int64_t>>,                       \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<plat::float16>>,                 \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                      \
                                ops::functor<plat::bfloat16>>);               \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      act_type##_grad,                                                        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<float>>,                \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<double>>,               \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<int>>,                  \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<int64_t>>,              \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<plat::float16>>,        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                  \
                                    ops::grad_functor<plat::bfloat16>>);
749

D
Double_V 已提交
750 751
/* ========================================================================== */

752 753 754 755 756 757 758 759 760 761 762 763 764
/* ======================== celu register  ============================ */
// Forward and first-order gradient kernels for celu.
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
                                CudaCELUGradFunctor);

// Second-order gradient kernel for celu: float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<float>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<double>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

L
lvmengsi 已提交
765
/* ===========================   sqrt register  ============================= */
766 767
// Forward and first-order gradient kernels for sqrt.
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);

// Second-order gradient kernel for sqrt: float, double, float16 and bfloat16.
REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::bfloat16>>);
L
lvmengsi 已提交
779 780
/* ========================================================================== */

W
whs 已提交
781 782
/* ===========================  rsqrt register  ============================= */
783 784
// Forward and first-order gradient kernels for rsqrt.
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);

// Second-order gradient kernel for rsqrt: float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

796
/* ===========================  square register  ============================ */
797 798
// Forward and first-order gradient kernels for square, including the integer
// element types.
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
                                    CudaSquareGradFunctor);

// Second-order gradient kernel for square: float, double, float16, bfloat16,
// int and int64_t.
REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::bfloat16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
814
/* ========================================================================== */
815 816 817 818 819

/* ==========================   pow register  ============================ */
// Forward kernel for pow: float, double, int, int64_t and float16.
REGISTER_OP_CUDA_KERNEL(
    pow,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);

// Gradient kernel for pow, registered for the same element types.
REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<plat::CUDADeviceContext,
                       ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
832

W
wangzhen38 已提交
833 834 835 836 837 838 839 840 841 842 843 844 845 846 847
/* ==========================   logit register  ============================ */
// `ops` is shorthand for paddle::operators in the registrations below.
namespace ops = paddle::operators;
// Forward `logit` kernels on CUDA for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    logit,
    ops::LogitKernel<plat::CUDADeviceContext, float>,
    ops::LogitKernel<plat::CUDADeviceContext, double>,
    ops::LogitKernel<plat::CUDADeviceContext, plat::float16>);
// Backward `logit_grad` kernels on CUDA for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    logit_grad,
    ops::LogitGradKernel<plat::CUDADeviceContext, float>,
    ops::LogitGradKernel<plat::CUDADeviceContext, double>,
    ops::LogitGradKernel<plat::CUDADeviceContext, plat::float16>);
/* ========================================================================== */

/* ==========================   exp register  ============================ */
// Forward `exp` kernels on CUDA. float, double and float16 are registered
// through ActivationCudaKernel with CudaExpFunctor; int and int64_t are
// registered through ActivationKernel with ExpFunctor.
REGISTER_OP_CUDA_KERNEL(
    exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                   ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<plat::float16>>);
// Backward `exp_grad` kernels on CUDA: ActivationGradCudaKernel with
// CudaExpGradFunctor for float, double, int, int64_t and float16.
REGISTER_OP_CUDA_KERNEL(
    exp_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                            ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   expm1 register  ============================ */

// Forward `expm1` kernels on CUDA: ActivationCudaKernel with
// CudaExpm1Functor for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    expm1,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1Functor<float>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1Functor<double>>,
    ops::ActivationCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1Functor<plat::float16>>);
// Backward `expm1_grad` kernels on CUDA: ActivationGradCudaKernel with
// CudaExpm1GradFunctor for float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    expm1_grad,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1GradFunctor<float>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1GradFunctor<double>>,
    ops::ActivationGradCudaKernel<
        plat::CUDADeviceContext, ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

// X-macro listing activation ops as
//   (op name, CamelCase name, forward functor, backward functor).
// `__macro` is invoked once per entry. Every line of the definition except
// the last must end in a backslash: any line without one terminates the
// #define and leaves the remaining entries outside the macro.
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                  \
  __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                      \
          CudaSoftShrinkGradFunctor);                                         \
  __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor);                  \
  __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor);               \
  __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor);               \
  __macro(reciprocal, Reciprocal, CudaReciprocalFunctor,                      \
          CudaReciprocalGradFunctor);                                         \
  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
  __macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor);              \
  __macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor);  \
  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);  \
  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);              \
  __macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor,                     \
          CudaTanhShrinkGradFunctor);                                         \
  __macro(hard_shrink, HardShrink, CudaHardShrinkFunctor,                     \
          CudaHardShrinkGradFunctor);                                         \
  __macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor);              \
  __macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor);                  \
  __macro(hard_swish, HardSwish, CudaHardSwishFunctor,                        \
          CudaHardSwishGradFunctor);
// Apply REGISTER_ACTIVATION_CUDA_KERNEL to every entry of the list above.
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)

#ifdef PADDLE_WITH_XPU_KP
// Kernel registrations for the KP library on plat::XPUPlace, compiled only
// when PADDLE_WITH_XPU_KP is defined. Each op registers a forward
// ActivationCudaKernel and a backward ActivationGradCudaKernel on
// XPUDeviceContext; every functor in this section is instantiated for float
// only. Ops are listed in alphabetical order.
REGISTER_OP_KERNEL(
    brelu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaBReluFunctor<float>>);
REGISTER_OP_KERNEL(
    brelu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaBReluGradFunctor<float>>);

REGISTER_OP_KERNEL(ceil, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaCeilFunctor<float>>);
REGISTER_OP_KERNEL(
    ceil_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaZeroGradFunctor<float>>);

REGISTER_OP_KERNEL(celu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaCELUFunctor<float>>);
REGISTER_OP_KERNEL(
    celu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaCELUGradFunctor<float>>);

REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaELUFunctor<float>>);
REGISTER_OP_KERNEL(
    elu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaELUGradFunctor<float>>);

REGISTER_OP_KERNEL(exp, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaExpFunctor<float>>);
REGISTER_OP_KERNEL(
    exp_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaExpGradFunctor<float>>);

REGISTER_OP_KERNEL(floor, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaFloorFunctor<float>>);
REGISTER_OP_KERNEL(
    floor_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaZeroGradFunctor<float>>);

REGISTER_OP_KERNEL(
    hard_shrink, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaHardShrinkFunctor<float>>);
REGISTER_OP_KERNEL(
    hard_shrink_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaHardShrinkGradFunctor<float>>);

REGISTER_OP_KERNEL(
    hard_sigmoid, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaHardSigmoidFunctor<float>>);
REGISTER_OP_KERNEL(
    hard_sigmoid_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaHardSigmoidGradFunctor<float>>);

REGISTER_OP_KERNEL(hard_swish, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaHardSwishFunctor<float>>);
REGISTER_OP_KERNEL(
    hard_swish_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaHardSwishGradFunctor<float>>);

REGISTER_OP_KERNEL(
    leaky_relu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaLeakyReluFunctor<float>>);
REGISTER_OP_KERNEL(
    leaky_relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaLeakyReluGradFunctor<float>>);

REGISTER_OP_KERNEL(log, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaLogFunctor<float>>);
REGISTER_OP_KERNEL(
    log_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaLogGradFunctor<float>>);

REGISTER_OP_KERNEL(log1p, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaLog1pFunctor<float>>);
REGISTER_OP_KERNEL(
    log1p_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaLog1pGradFunctor<float>>);

REGISTER_OP_KERNEL(
    logsigmoid, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaLogSigmoidFunctor<float>>);
REGISTER_OP_KERNEL(
    logsigmoid_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaLogSigmoidGradFunctor<float>>);

REGISTER_OP_KERNEL(
    reciprocal, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaReciprocalFunctor<float>>);
REGISTER_OP_KERNEL(
    reciprocal_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaReciprocalGradFunctor<float>>);

REGISTER_OP_KERNEL(
    relu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaReluFunctor<float>>);
REGISTER_OP_KERNEL(
    relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaReluGradFunctor<float>>);

REGISTER_OP_KERNEL(relu6, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaRelu6Functor<float>>);
REGISTER_OP_KERNEL(
    relu6_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaRelu6GradFunctor<float>>);

REGISTER_OP_KERNEL(sigmoid, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSigmoidFunctor<float>>);
REGISTER_OP_KERNEL(
    sigmoid_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSigmoidGradFunctor<float>>);

REGISTER_OP_KERNEL(silu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSiluFunctor<float>>);
REGISTER_OP_KERNEL(
    silu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSiluGradFunctor<float>>);

REGISTER_OP_KERNEL(soft_relu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSoftReluFunctor<float>>);
REGISTER_OP_KERNEL(
    soft_relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftReluGradFunctor<float>>);

REGISTER_OP_KERNEL(softplus, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSoftplusFunctor<float>>);
REGISTER_OP_KERNEL(
    softplus_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftplusGradFunctor<float>>);

REGISTER_OP_KERNEL(
    softshrink, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaSoftShrinkFunctor<float>>);
REGISTER_OP_KERNEL(
    softshrink_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftShrinkGradFunctor<float>>);

REGISTER_OP_KERNEL(softsign, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSoftsignFunctor<float>>);
REGISTER_OP_KERNEL(
    softsign_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftsignGradFunctor<float>>);

REGISTER_OP_KERNEL(sqrt, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSqrtFunctor<float>>);
REGISTER_OP_KERNEL(
    sqrt_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSqrtGradFunctor<float>>);

REGISTER_OP_KERNEL(square, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSquareFunctor<float>>);
REGISTER_OP_KERNEL(
    square_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSquareGradFunctor<float>>);

REGISTER_OP_KERNEL(swish, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSwishFunctor<float>>);
REGISTER_OP_KERNEL(
    swish_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSwishGradFunctor<float>>);

REGISTER_OP_KERNEL(
    thresholded_relu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaThresholdedReluFunctor<float>>);
REGISTER_OP_KERNEL(
    thresholded_relu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaThresholdedReluGradFunctor<float>>);

#endif  // PADDLE_WITH_XPU_KP