activation_op.kps 49.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
15
#include "paddle/fluid/platform/bfloat16.h"
16
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
17

18 19
#include "paddle/phi/kernels/funcs/activation_functor.h"

20 21 22 23
namespace paddle {
namespace operators {

template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Logistic sigmoid: out = 1 / (1 + exp(-x)).
  // The input is converted to MPType (the compute type chosen by
  // details::MPTypeTrait for T) before evaluation; the result is cast back
  // to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType neg_x = -static_cast<MPType>(arg_x);
    const MPType denom = one + exp(neg_x);
    return static_cast<T>(one / denom);
  }
};
34

35 36 37 38 39
template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Sigmoid backward written in terms of the forward output:
  //   dx = dout * out * (1 - out)
  // All arithmetic is performed directly in T.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T scaled = dout * out;
    return scaled * (one - out);
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

49 50 51 52 53 54 55 56 57
template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x)))
  // Evaluated as -(s + log(exp(-s) + exp(-x - s))) with s = max(-x, 0).
  // Subtracting s makes both exponents <= 0, so neither exp() can overflow.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    MPType shift = -v;
    if (v > zero) {
      shift = zero;
    }
    const MPType lse = log(exp(-shift) + exp(-v - shift));
    return static_cast<T>(-shift - lse);
  }
};
64 65

template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // Gradient of logsigmoid: dx = dout * exp(-x) / (1 + exp(-x)).
  // Numerator and denominator are both scaled by exp(-s), s = max(-x, 0):
  //   dx = dout * exp(-x - s) / (exp(-s) + exp(-x - s))
  // which keeps every exponent <= 0.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    const MPType shift = (v > zero) ? zero : -v;
    const MPType num = exp(-v - shift);
    const MPType den = exp(-shift) + num;
    return static_cast<T>(g * (num / den));
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Elementwise ceil(); evaluated in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(ceil(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Elementwise floor(); evaluated in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(floor(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Elementwise round() to the nearest integer; evaluated in MPType and cast
  // back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(round(static_cast<MPType>(arg_x)));
  }
};

119
// Gradient functor shared by ceil, floor and round. These functions are
// piecewise constant, so the gradient reported is zero everywhere.
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  // The argument is ignored; a T zero is returned unconditionally.
  __device__ __forceinline__ T operator()(const T /*x*/) const {
    const T kZero = static_cast<T>(0.0f);
    return kZero;
  }

  // Needs neither X nor Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kNoDeps;
  }
};

131
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // With out = 1 / x, d(1/x)/dx = -1/x^2 = -out^2, hence
  //   dx = -dout * out * out
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T neg_dout = -dout;
    return neg_dout * out * out;
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

143 144 145 146 147
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Natural logarithm; evaluated in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(log(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // d(log x)/dx = 1/x, hence dx = dout / x (computed in T).
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T quotient = dout / x;
    return quotient;
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // d(x^2)/dx = 2x, hence dx = dout * 2 * x (computed in T).
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T twice_dout = dout * two;
    return twice_dout * x;
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

176 177 178 179 180
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // d(sqrt x)/dx = 0.5 / sqrt(x) = 0.5 / out, hence dx = 0.5 * dout / out.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T half_dout = one_half * dout;
    return half_dout / out;
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
189

190 191 192 193
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  T minus_one_half = static_cast<T>(-0.5f);

  // With out = x^(-1/2), d(out)/dx = -0.5 * x^(-3/2) = -0.5 * out^3, so
  //   dx = -0.5 * dout * out^3
  // The product is accumulated left to right, one factor of out at a time.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    T acc = minus_one_half * dout;
    acc = acc * out;
    acc = acc * out;
    acc = acc * out;
    return acc;
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
203

204 205 206 207 208 209
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // Retained so the functor's public members are unchanged; operator() no
  // longer needs it.
  MPType one = static_cast<MPType>(1.0f);

  // log1p(x) = log(1 + x)
  // Uses the dedicated log1p() math function rather than log(1 + x).
  // Forming 1 + x first discards the low-order bits of x. For |x| << 1 the
  // naive form therefore loses most of its relative accuracy, and it returns
  // exactly 0 once 1 + x rounds to 1.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    return static_cast<T>(log1p(x));
  }
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // d(log(1 + x))/dx = 1 / (1 + x), hence dx = dout / (1 + x) (in T).
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T denom = one + x;
    return dout / denom;
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Base-2 logarithm; evaluated in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(log2(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // ln(2), computed once in MPType when the functor is constructed, then
  // stored as T.
  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

  // d(log2 x)/dx = 1 / (x * ln 2), hence dx = dout / (x * ln 2) (in T).
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T scaled_x = x * log_two;
    return dout / scaled_x;
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  // Base-10 logarithm; evaluated in MPType and cast back to T.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    return static_cast<T>(log10(static_cast<MPType>(arg_x)));
  }
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  // ln(10), computed once in MPType when the functor is constructed, then
  // stored as T.
  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

  // d(log10 x)/dx = 1 / (x * ln 10), hence dx = dout / (x * ln 10) (in T).
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T scaled_x = x * log_ten;
    return dout / scaled_x;
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold)))
  // threshold is expected to be non-negative.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType limit = static_cast<MPType>(threshold);
    // Clamp from above first, then from below (same order as min then max).
    MPType clipped = v;
    if (!(clipped < limit)) {
      clipped = limit;
    }
    if (!(clipped > -limit)) {
      clipped = -limit;
    }
    return static_cast<T>(log(one + exp(clipped)));
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout * (1 - exp(-out))  when -threshold < out < threshold
  // dx = 0                       otherwise
  // threshold is expected to be non-negative.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType o = static_cast<MPType>(arg_out);
    const MPType limit = static_cast<MPType>(threshold);
    const bool in_range = (o > -limit) && (o < limit);
    if (!in_range) {
      return static_cast<T>(0.0f);
    }
    return static_cast<T>(g * (one - exp(-o)));
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float scale_a;
  float scale_b;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // Backward of b * tanh(a * x):
  //   dx = dout * a * b * (1 - tanh(a * x)^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    const MPType sa = static_cast<MPType>(scale_a);
    const MPType sb = static_cast<MPType>(scale_b);
    const MPType th = tanh(sa * v);
    const MPType sech2 = one - th * th;
    return static_cast<T>(g * sa * sb * sech2);
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
  // Assuming the forward is softplus(x) = log(1 + exp(beta * x)) / beta, the
  // second branch is dout * sigmoid(beta * x). Above the threshold the
  // forward is treated as linear, so the gradient passes through unchanged.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    // Use the MPType copy of beta; the original computed `b` and then left it
    // unused, multiplying by the raw float attribute instead.
    MPType x_beta = x * b;
    // Pass-through branch returns the original T value (no round trip).
    return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // Backward of x / (1 + |x|):
  //   dx = dout / (1 + |x|)^2   (computed in T)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T base = one + abs(x);
    const T squared = base * base;
    return dout / squared;
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), threshold)
  // (The classic relu6 uses a cap of 6; here the cap is the threshold
  // attribute.)
  __device__ __forceinline__ T operator()(const T x) const {
    const T cap = static_cast<T>(threshold);
    if (x <= zero) {
      return zero;
    }
    return (x < cap) ? x : cap;
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // The gradient flows only where the forward was in its linear part:
  //   dx = dout  if 0 < out < threshold
  //   dx = 0     otherwise
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T cap = static_cast<T>(threshold);
    const bool passes = (out > zero) && (out < cap);
    return passes ? dout : zero;
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = clamp(x * slope + offset, 0, 1)
  // For slope > 0 this is 0 when x <= -offset / slope, 1 when
  // x >= (1 - offset) / slope, and x * slope + offset in between.
  __device__ __forceinline__ T operator()(const T x) const {
    const T lin = x * static_cast<T>(slope) + static_cast<T>(offset);
    T clamped = zero;
    if (lin > zero) {
      clamped = lin;
    }
    if (!(clamped < one)) {
      clamped = one;
    }
    return clamped;
  }
};

template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // dx = dout * slope  if 0 < out < 1 (forward in its linear part)
  // dx = 0             otherwise (forward clamped)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    if (!(out > zero && out < one)) {
      return zero;
    }
    return dout * static_cast<T>(slope);
  }

  // The backward pass reads only Out from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x / (1 + exp(-beta * x)) = x * sigmoid(beta * x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType v = static_cast<MPType>(arg_x);
    const MPType bv = static_cast<MPType>(beta);
    const MPType denom = one + exp(-bv * v);
    return static_cast<T>(v / denom);
  }
};

template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // With s = sigmoid(b * x) and y = x * s (the forward output):
  //   dx = dout * (s + b * x * s * (1 - s))
  //      = dout * (b * y + s * (1 - b * y))
  // Equivalently dout * (1 + e + b * x * e) / (1 + e)^2 with e = exp(-b * x).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    const MPType bv = static_cast<MPType>(beta);
    const MPType sig = one / (one + exp(-bv * v));  // sigmoid(b * x)
    const MPType fwd = v * sig;                      // swish(x)
    const MPType beta_fwd = bv * fwd;
    return static_cast<T>(g * (beta_fwd + sig * (one - beta_fwd)));
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

509 510 511 512 513 514 515 516 517 518 519 520 521 522
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Backward of mish(x) = x * tanh(sp), sp = softplus(x):
  //   dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * sp'(x))
  // with sp = x and sp' = 1 when x > threshold (linear approximation), and
  // sp = log(1 + exp(x)), sp' = 1 / (1 + exp(-x)) otherwise.
  // Operands: dout first, then the forward input x.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType g = static_cast<MPType>(arg_dout);
    const MPType v = static_cast<MPType>(arg_x);
    MPType sp;   // softplus(v)
    MPType dsp;  // d softplus(v) / dv
    if (v > static_cast<MPType>(threshold)) {
      sp = v;
      dsp = one;
    } else {
      sp = log(one + exp(v));
      dsp = one / (one + exp(-v));
    }
    const MPType tsp = tanh(sp);
    return static_cast<T>(g * (tsp + v * (one - tsp * tsp) * dsp));
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

537 538 539 540 541 542 543 544 545 546 547 548 549 550 551
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale, i.e.
  //   0                        when x <= -offset
  //   x * threshold / scale    when x >= threshold - offset
  //                            (just x for the default threshold == scale)
  //   x * (x + offset) / scale otherwise
  // threshold = scale = 6, offset = 3 by default.
  __device__ __forceinline__ T operator()(const T x) const {
    const T cap = static_cast<T>(threshold);
    const T shifted = x + static_cast<T>(offset);
    T gate = zero;
    if (shifted > zero) {
      gate = shifted;
    }
    if (!(gate < cap)) {
      gate = cap;
    }
    return gate * x / static_cast<T>(scale);
  }
};

template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  T two = static_cast<T>(2.0f);
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // Backward of hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale:
  //   dx = 0,                               when x + offset <= 0
  //   dx = dout * threshold / scale,        when x + offset >= threshold
  //   dx = dout * (2 * x + offset) / scale, otherwise
  // threshold = scale = 6, offset = 3 by default.
  // The saturated branch uses threshold / scale, the slope of the forward
  // x * threshold / scale there. The previous version used a constant 1,
  // which is only correct when threshold == scale (still 1 for the defaults).
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T o = static_cast<T>(offset);
    T s = static_cast<T>(scale);
    T t = static_cast<T>(threshold);
    T temp1 = static_cast<T>(x + o > zero);  // 1 above the lower knee
    T temp2 = static_cast<T>(x + o < t);     // 1 below the upper knee
    return dout * (temp1 * temp2 * (two * x + o) / s + (one - temp2) * t / s);
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

589 590 591 592 593 594 595 596 597 598 599 600
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT v = static_cast<CT>(arg_x);
    const CT a = static_cast<CT>(alpha);
    const CT elu_part = a * (exp(v / a) - one);
    const CT pos = (v > zero) ? v : zero;                // max(0, x)
    const CT neg = (elu_part > zero) ? zero : elu_part;  // min(0, ...)
    return static_cast<T>(pos + neg);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  // Retained so the functor's public members are unchanged; operator() no
  // longer needs it.
  MPType one = static_cast<MPType>(1.0f);
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Backward of celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)).
  // For either sign of a non-zero alpha the four cases collapse to:
  //   dx = dout,                  if x > 0
  //   dx = dout * exp(x / alpha), if x <= 0
  // (alpha == 0 is not meaningful here: x / alpha is undefined.)
  // The branch is selected rather than blended with 0/1 masks. exp(x / alpha)
  // may overflow to inf on the unused side (e.g. alpha = 1 and x > ~88 in
  // float), and a mask product 0 * inf would turn the gradient into NaN.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    if (x > zero) {
      return static_cast<T>(dout);
    }
    MPType a = static_cast<MPType>(alpha);
    return static_cast<T>(dout * exp(x / a));
  }

  // The backward pass reads only X from the forward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

642
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Forward activation kernel:
  //   1. resolve X and Out through ExtractActivationTensor,
  //   2. allocate Out on the execution place,
  //   3. copy every float attribute the functor exposes via GetAttrs() from
  //      the op,
  //   4. launch one same-shape elementwise kernel writing Functor(X) to Out.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* src = nullptr;
    framework::Tensor* dst = nullptr;
    ExtractActivationTensor(ctx, &src, &dst);
    dst->mutable_data<T>(ctx.GetPlace());

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    std::vector<const framework::Tensor*> inputs{src};
    std::vector<framework::Tensor*> outputs{dst};

    // Constructed with Functor() (value-initialisation) before its attribute
    // slots are filled from the op.
    auto functor = Functor();
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = ctx.Attr<float>(name_and_slot.first);
    }

    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        dev_ctx, inputs, &outputs, functor);
  }
};

template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Backward activation kernel. It gathers dOut plus whichever forward tensor
  // the functor declares through FwdDeps(): Out for kDepOut, X for kDepX, and
  // nothing extra otherwise. It then allocates dX, fills the functor's float
  // attributes from the op, and launches a single same-shape elementwise
  // kernel that writes dX.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto functor = Functor();
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = ctx.Attr<float>(name_and_slot.first);
    }

    // dOut is always the first operand of the grad functor.
    std::vector<const framework::Tensor*> inputs{d_out};
    std::vector<framework::Tensor*> outputs{d_x};
    constexpr int kDeps = static_cast<int>(Functor::FwdDeps());
    if (kDeps == static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      inputs.push_back(out);  // second operand: forward output Out
    } else if (kDeps == static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      inputs.push_back(x);  // second operand: forward input X
    }

    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        dev_ctx, inputs, &outputs, functor);
  }
};

706 707 708 709 710 711 712 713 714 715 716 717 718 719 720
// Bring CUDA activation functors implemented in phi into this namespace.
// NOTE(review): USE_PHI_FUNCTOR is defined outside this file; presumably it
// declares aliases such as Cuda<Name>Functor / Cuda<Name>GradFunctor that
// refer to phi::funcs — confirm against its definition in activation_op.h.
USE_PHI_FUNCTOR(CudaCos)
USE_PHI_FUNCTOR(CudaTan)
USE_PHI_FUNCTOR(CudaAcos)
USE_PHI_FUNCTOR(CudaSin)
USE_PHI_FUNCTOR(CudaAsin)
USE_PHI_FUNCTOR(CudaAtan)
USE_PHI_FUNCTOR(CudaSinh)
USE_PHI_FUNCTOR(CudaCosh)
USE_PHI_FUNCTOR(CudaAsinh)
USE_PHI_FUNCTOR(CudaAcosh)
USE_PHI_FUNCTOR(CudaAtanh)
USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
Y
YuanRisheng 已提交
721 722 723 724 725 726 727 728 729
// More phi functors brought into this namespace (see the USE_PHI_FUNCTOR
// definition outside this file for exactly what each invocation declares).
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)

// Local alias for phi's CudaELUGradNegativeAlphaFunctor. The name suggests it
// is the ELU gradient variant used for negative alpha — TODO confirm in
// phi/kernels/funcs/activation_functor.h.
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
    phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;
730

731 732 733
}  // namespace operators
}  // namespace paddle

734
namespace ops = paddle::operators;
735 736
namespace plat = paddle::platform;

737 738
// Registers the forward CUDA kernel `act_type` and its backward kernel
// `act_type##_grad` for float, double, float16 and bfloat16.
// The forward uses ActivationCudaKernel<functor<T>> and the backward uses
// ActivationGradCudaKernel<grad_functor<T>>.
// `op_name` is part of the macro signature but is not used in the expansion.
// The expansion refers to the `ops` and `plat` namespace aliases, so both
// must be visible where the macro is invoked.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,     \
                                        grad_functor)                   \
  REGISTER_OP_CUDA_KERNEL(                                              \
      act_type,                                                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<float>>,                   \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<double>>,                  \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<plat::float16>>,           \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<plat::bfloat16>>);         \
  REGISTER_OP_CUDA_KERNEL(                                              \
      act_type##_grad,                                                  \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<float>>,          \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<double>>,         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<plat::float16>>,  \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<plat::bfloat16>>);
758

759 760 761 762 763 764 765 766 767 768 769 770
// Same as REGISTER_ACTIVATION_CUDA_KERNEL, but the forward `act_type` and
// backward `act_type##_grad` kernels are additionally instantiated for the
// integer types int and int64_t. Full type list, in registration order:
// float, double, int, int64_t, plat::float16, plat::bfloat16.
// `op_name` is not referenced by the expansion. The expansion uses the
// `ops` / `plat` namespace aliases, which must be in scope where the macro
// is invoked.
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor, \
                                            grad_functor)               \
  REGISTER_OP_CUDA_KERNEL(                                              \
      act_type,                                                         \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<float>>,                   \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<double>>,                  \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<int>>,                     \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<int64_t>>,                 \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<plat::float16>>,           \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                \
                                ops::functor<plat::bfloat16>>);         \
  REGISTER_OP_CUDA_KERNEL(                                              \
      act_type##_grad,                                                  \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<float>>,          \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<double>>,         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<int>>,            \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<int64_t>>,        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<plat::float16>>,  \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,            \
                                    ops::grad_functor<plat::bfloat16>>);
788

D
Double_V 已提交
789 790
/* ========================================================================== */

791 792 793 794 795 796 797 798 799 800 801 802 803
/* ============================= celu register ============================= */
// `celu` and `celu_grad` come from the shared activation macro; the
// second-order `celu_grad_grad` kernel is registered on its own for
// float, double and float16.
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
                                CudaCELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<float>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<double>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

804 805 806 807 808 809 810 811 812 813 814 815
/* =========================== sigmoid register ============================ */
// `sigmoid` / `sigmoid_grad` via the shared activation macro, plus the
// second-order (`sigmoid_grad_grad`) and third-order (`sigmoid_triple_grad`)
// kernels, each for float, double, float16 and bfloat16.
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
                                CudaSigmoidGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    sigmoid_grad_grad,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<float>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<double>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::float16>>,
    ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::bfloat16>>);

REGISTER_OP_CUDA_KERNEL(
    sigmoid_triple_grad,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<float>>,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<double>>,
    ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
                                 ops::SigmoidTripleGradFunctor<plat::float16>>,
    ops::SigmoidTripleGradKernel<
        plat::CUDADeviceContext,
        ops::SigmoidTripleGradFunctor<plat::bfloat16>>);
/* ========================================================================== */

L
lvmengsi 已提交
833
/* ===========================   sqrt register  ============================= */
834 835
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
                                CudaSqrtGradFunctor);
L
lvmengsi 已提交
836 837 838 839 840 841 842 843

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
844 845 846
                              ops::SqrtGradGradFunctor<plat::float16>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::bfloat16>>);
L
lvmengsi 已提交
847 848
/* ========================================================================== */

W
whs 已提交
849 850
/* ===========================   rsqrt register  =============================
 */
851 852
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
                                CudaRsqrtGradFunctor);
W
whs 已提交
853 854 855 856 857 858 859 860 861 862 863

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

864
/* ===========================  square register  ============================ */
865 866
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
                                    CudaSquareGradFunctor);
867 868 869 870 871 872 873 874

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
875
                                ops::SquareGradGradFunctor<plat::float16>>,
876 877
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::bfloat16>>,
878 879 880 881
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
882
/* ========================================================================== */
883 884 885 886 887

/* ==========================   pow register  ============================ */
// `pow` / `pow_grad` use dedicated PowKernel / PowGradKernel wrappers
// (not the generic activation macro) for float, double, int, int64_t and
// float16.
REGISTER_OP_CUDA_KERNEL(
    pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<plat::CUDADeviceContext,
                       ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
900

W
wangzhen38 已提交
901 902 903 904 905 906 907 908 909 910
/* ==========================   logit register  ============================ */
// Only the backward `logit_grad` kernel is registered in this section, for
// float, double and float16. The file-scope `ops` alias is reused here rather
// than re-declared.
// NOTE(review): the forward `logit` kernel is presumably registered elsewhere
// (not visible in this file section) — confirm.
REGISTER_OP_CUDA_KERNEL(
    logit_grad,
    ops::LogitGradKernel<plat::CUDADeviceContext, float>,
    ops::LogitGradKernel<plat::CUDADeviceContext, double>,
    ops::LogitGradKernel<plat::CUDADeviceContext, plat::float16>);
/* ========================================================================== */

911 912
/* ==========================   exp register  ============================ */
// Forward `exp`:
//   - float, double and float16 use ActivationCudaKernel with CudaExpFunctor;
//   - int and int64_t use the generic ActivationKernel with ExpFunctor.
//   NOTE(review): presumably CudaExpFunctor is not usable for integer types —
//   confirm before unifying these registrations.
// Backward `exp_grad` uses CudaExpGradFunctor for all five types
// (float, double, int, int64_t, float16).
REGISTER_OP_CUDA_KERNEL(
    exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
                                   ops::CudaExpFunctor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    exp_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                            ops::CudaExpGradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */

R
ronnywang 已提交
934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951
/* ==========================   expm1 register  ============================ */
// `expm1` and `expm1_grad` are spelled out explicitly (rather than through
// REGISTER_ACTIVATION_CUDA_KERNEL) and cover float, double and float16 only.
REGISTER_OP_CUDA_KERNEL(
    expm1,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<float>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<double>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaExpm1Functor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
    expm1_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */

952
/* ==========================  Log register ==================================*/
953
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, CudaLogFunctor, CudaLogGradFunctor);
954 955 956 957 958 959 960 961 962

REGISTER_OP_CUDA_KERNEL(
    log_grad_grad, ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                                            ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
963

964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988
// X-macro listing activations whose forward/backward CUDA kernels are all
// registered the same way. Each entry is
//   act_macro(act_type, OpName, ForwardFunctor, GradFunctor);
// ceil, floor and round use CudaZeroGradFunctor as their backward functor.
// The list is expanded once below with REGISTER_ACTIVATION_CUDA_KERNEL.
#define FOR_EACH_ACTIVATION_CUDA_OP(act_macro)                                  \
  act_macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor,                      \
            CudaLogSigmoidGradFunctor);                                         \
  act_macro(softshrink, SoftShrink, CudaSoftShrinkFunctor,                      \
            CudaSoftShrinkGradFunctor);                                         \
  act_macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor);                  \
  act_macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor);               \
  act_macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor);               \
  act_macro(reciprocal, Reciprocal, CudaReciprocalFunctor,                      \
            CudaReciprocalGradFunctor);                                         \
  act_macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor);              \
  act_macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor);                  \
  act_macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor);              \
  act_macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
  act_macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor);              \
  act_macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor);  \
  act_macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);  \
  act_macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);              \
  act_macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor,                     \
            CudaTanhShrinkGradFunctor);                                         \
  act_macro(hard_shrink, HardShrink, CudaHardShrinkFunctor,                     \
            CudaHardShrinkGradFunctor);                                         \
  act_macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor,                  \
            CudaHardSigmoidGradFunctor);                                        \
  act_macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor);              \
  act_macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor);                  \
  act_macro(hard_swish, HardSwish, CudaHardSwishFunctor,                        \
            CudaHardSwishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
993 994

#ifdef PADDLE_WITH_XPU_KP
995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210
// XPU KP backend: every activation below registers exactly two float-only
// kernels on plat::XPUPlace with library type KP:
//   - `act_type`        : ActivationCudaKernel<XPUDeviceContext,
//                                              functor<float>>
//   - `act_type##_grad` : ActivationGradCudaKernel<XPUDeviceContext,
//                                                  grad_functor<float>>
// `functor` / `grad_functor` must be passed fully qualified (ops:: or
// phi::funcs::). The helper is #undef'd right after use.
#define REGISTER_ACTIVATION_XPU_KP_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(                                                      \
      act_type, KP, plat::XPUPlace,                                        \
      ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,        \
                                functor<float>>);                          \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, KP, plat::XPUPlace,                                 \
      ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,    \
                                    grad_functor<float>>)

REGISTER_ACTIVATION_XPU_KP_KERNEL(brelu, phi::funcs::CudaBReluFunctor,
                                  phi::funcs::CudaBReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(ceil, ops::CudaCeilFunctor,
                                  ops::CudaZeroGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(celu, ops::CudaCELUFunctor,
                                  ops::CudaCELUGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(elu, ops::CudaELUFunctor,
                                  ops::CudaELUGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(exp, ops::CudaExpFunctor,
                                  ops::CudaExpGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(floor, ops::CudaFloorFunctor,
                                  ops::CudaZeroGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(hard_shrink, ops::CudaHardShrinkFunctor,
                                  ops::CudaHardShrinkGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(hard_sigmoid, ops::CudaHardSigmoidFunctor,
                                  ops::CudaHardSigmoidGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(hard_swish, ops::CudaHardSwishFunctor,
                                  ops::CudaHardSwishGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(leaky_relu, phi::funcs::CudaLeakyReluFunctor,
                                  phi::funcs::CudaLeakyReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(log, ops::CudaLogFunctor,
                                  ops::CudaLogGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(log1p, ops::CudaLog1pFunctor,
                                  ops::CudaLog1pGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(logsigmoid, ops::CudaLogSigmoidFunctor,
                                  ops::CudaLogSigmoidGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(reciprocal, ops::CudaReciprocalFunctor,
                                  ops::CudaReciprocalGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(relu, phi::funcs::CudaReluFunctor,
                                  phi::funcs::CudaReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(relu6, ops::CudaRelu6Functor,
                                  ops::CudaRelu6GradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(sigmoid, ops::CudaSigmoidFunctor,
                                  ops::CudaSigmoidGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(silu, ops::CudaSiluFunctor,
                                  ops::CudaSiluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(soft_relu, ops::CudaSoftReluFunctor,
                                  ops::CudaSoftReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(softplus, ops::CudaSoftplusFunctor,
                                  ops::CudaSoftplusGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(softshrink, ops::CudaSoftShrinkFunctor,
                                  ops::CudaSoftShrinkGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(softsign, ops::CudaSoftsignFunctor,
                                  ops::CudaSoftsignGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(sqrt, ops::CudaSqrtFunctor,
                                  ops::CudaSqrtGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(square, ops::CudaSquareFunctor,
                                  ops::CudaSquareGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(swish, ops::CudaSwishFunctor,
                                  ops::CudaSwishGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(thresholded_relu,
                                  ops::CudaThresholdedReluFunctor,
                                  ops::CudaThresholdedReluGradFunctor);

#undef REGISTER_ACTIVATION_XPU_KP_KERNEL
1211 1212

#endif  // PADDLE_WITH_XPU_KP