activation_op.kps 28.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
15
#include "paddle/fluid/platform/bfloat16.h"
16
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
17

18 19
#include "paddle/phi/kernels/funcs/activation_functor.h"

20 21 22
namespace paddle {
namespace operators {

template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  // Intermediate compute type from details::MPTypeTrait (presumably a wider
  // float type when T is float16/bfloat16 — TODO confirm).
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Clipping bound; written from the op attribute "threshold" through the
  // pointer returned by GetAttrs().
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
  // x is clipped from above by threshold first, then from below by
  // -threshold. threshold is expected to be non-negative.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType bound = static_cast<MPType>(threshold);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType upper_clipped = (x < bound) ? x : bound;
    const MPType clipped = (upper_clipped > -bound) ? upper_clipped : -bound;
    const MPType result = log(one + exp(clipped));
    return static_cast<T>(result);
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Written from the op attribute "threshold" through GetAttrs().
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // Gradient of soft_relu, expressed through the forward output `out`:
  //   dx = dout * (1 - exp(-out))   if -threshold < out < threshold
  //   dx = 0                        otherwise
  // threshold is expected to be non-negative.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    const MPType bound = static_cast<MPType>(threshold);
    const MPType out = static_cast<MPType>(arg_out);
    if (!(out > -bound && out < bound)) {
      return static_cast<T>(0.0f);
    }
    const MPType dout = static_cast<MPType>(arg_dout);
    return static_cast<T>(dout * (one - exp(-out)));
  }

  // The gradient consumes the forward output Out rather than the input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  // Upper clamp; written from the op attribute "threshold" via GetAttrs().
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), threshold). The upper bound is the
  // "threshold" attribute, not a hard-coded 6.
  __device__ __forceinline__ T operator()(const T x) const {
    if (x <= zero) {
      return zero;
    }
    const T upper = static_cast<T>(threshold);
    return (x < upper) ? x : upper;
  }
};

template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  // Upper clamp used in the forward pass; written from the op attribute
  // "threshold" via GetAttrs().
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = dout where 0 < out < threshold, and 0 everywhere else. The mask is
  // computed from the forward output.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T upper = static_cast<T>(threshold);
    const bool inside = (out > zero) && (out < upper);
    return inside ? dout : zero;
  }

  // The gradient consumes the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + |x|), evaluated directly in T.
  __device__ __forceinline__ T operator()(const T x) const {
    const T denominator = one + abs(x);
    return x / denominator;
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + |x|)^2, evaluated directly in T.
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T denominator = one + abs(x);
    const T denominator_sq = denominator * denominator;
    return dout / denominator_sq;
  }

  // The gradient consumes the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  // Intermediate compute type from details::MPTypeTrait.
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  // Written from the op attribute "alpha" via GetAttrs().
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), computed in
  // CT and cast back to T. Divides by alpha, so alpha is assumed non-zero —
  // TODO confirm the op validates this.
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    const CT a = static_cast<CT>(alpha);
    const CT positive_part = (x > zero) ? x : zero;
    const CT scaled_expm1 = a * (exp(x / a) - one);
    const CT negative_part = (scaled_expm1 > zero) ? zero : scaled_expm1;
    return static_cast<T>(positive_part + negative_part);
  }
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  // Not read by operator(); kept because it is one of the struct's members.
  MPType one = static_cast<MPType>(1.0f);
  // Written from the op attribute "alpha" via GetAttrs().
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Gradient of celu with respect to x, for either sign of alpha
  // (alpha assumed non-zero):
  //   dx = dout                    if x > 0
  //   dx = dout * exp(x / alpha)   if x <= 0
  // The case is selected with a branch, not by summing both cases weighted
  // with 0/1 masks. A masked product such as 0 * exp(x / alpha) evaluates to
  // NaN once exp overflows to inf (e.g. alpha = 1 and x > ~88.7 in float).
  // That would turn dx into NaN where the true gradient is simply dout.
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    if (x > zero) {
      return static_cast<T>(dout);
    }
    MPType a = static_cast<MPType>(alpha);
    return static_cast<T>(dout * exp(x / a));
  }

  // The gradient consumes the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Forward activation kernel: computes Out = functor(X) element-wise.
  //
  // X and Out are located by ExtractActivationTensor, and Out is allocated on
  // ctx.GetPlace(). Each float attribute listed by Functor::GetAttrs() is read
  // from the op and stored into the functor before the launch.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* src = nullptr;
    framework::Tensor* dst = nullptr;
    ExtractActivationTensor(ctx, &src, &dst);
    dst->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    std::vector<const framework::Tensor*> ins{src};
    std::vector<framework::Tensor*> outs{dst};

    auto functor = Functor();
    auto attr_slots = functor.GetAttrs();
    for (auto& slot : attr_slots) {
      // slot.first: attribute name; slot.second: functor member to fill.
      *slot.second = ctx.Attr<float>(slot.first);
    }

    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  }
};

template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  // Backward activation kernel: computes dX = functor(dOut[, dep]) element-wise.
  //
  // Functor::FwdDeps() selects the extra input `dep`: the forward output Out
  // (kDepOut), the forward input X (kDepX), or none (any other value, in
  // which case only dOut is passed). Float attributes listed by
  // Functor::GetAttrs() are copied from the op before the launch.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* d_out = nullptr;
    framework::Tensor* d_x = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(ctx, &x, &out, &d_out,
                                                    &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto functor = Functor();
    auto attr_slots = functor.GetAttrs();
    for (auto& slot : attr_slots) {
      *slot.second = ctx.Attr<float>(slot.first);
    }

    std::vector<const framework::Tensor*> ins{d_out};
    std::vector<framework::Tensor*> outs{d_x};

    constexpr int kFwdDeps = static_cast<int>(Functor::FwdDeps());
    if (kFwdDeps == static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      ins.push_back(out);  // gradient formula reads the forward output
    } else if (kFwdDeps == static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      ins.push_back(x);  // gradient formula reads the forward input
    }

    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
                                                              &outs, functor);
  }
};

// Functors taken from phi::funcs. USE_PHI_FUNCTOR is defined outside this
// file. Later code here refers to e.g. ops::CudaExpFunctor and
// ops::CudaExpGradFunctor, so USE_PHI_FUNCTOR(CudaX) presumably provides
// CudaXFunctor / CudaXGradFunctor in this namespace — TODO confirm.
USE_PHI_FUNCTOR(CudaCos)
USE_PHI_FUNCTOR(CudaTan)
USE_PHI_FUNCTOR(CudaAcos)
USE_PHI_FUNCTOR(CudaSin)
USE_PHI_FUNCTOR(CudaAsin)
USE_PHI_FUNCTOR(CudaAtan)
USE_PHI_FUNCTOR(CudaSinh)
USE_PHI_FUNCTOR(CudaCosh)
USE_PHI_FUNCTOR(CudaAsinh)
USE_PHI_FUNCTOR(CudaAcosh)
USE_PHI_FUNCTOR(CudaAtanh)
USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
USE_PHI_FUNCTOR(CudaSwish)
USE_PHI_FUNCTOR(CudaHardSwish)
USE_PHI_FUNCTOR(CudaExp)
USE_PHI_FUNCTOR(CudaExpm1)
USE_PHI_FUNCTOR(CudaMish)
USE_PHI_FUNCTOR(CudaSTanh)
USE_PHI_FUNCTOR(CudaReciprocal)
USE_PHI_FUNCTOR(CudaSquare)
USE_PHI_FUNCTOR(CudaSqrt)
USE_PHI_FUNCTOR(CudaRsqrt)
USE_PHI_FUNCTOR(CudaSoftplus)

// Single phi::funcs functors re-exported under the same name. These have no
// separate forward/grad pair going through USE_PHI_FUNCTOR.
template <typename T>
using CudaRoundFunctor = phi::funcs::CudaRoundFunctor<T>;
template <typename T>
using CudaFloorFunctor = phi::funcs::CudaFloorFunctor<T>;
template <typename T>
using CudaCeilFunctor = phi::funcs::CudaCeilFunctor<T>;
template <typename T>
using CudaZeroGradFunctor = phi::funcs::CudaZeroGradFunctor<T>;
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
    phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the CUDA forward kernel `act_type` and its gradient kernel
// `act_type##_grad`, each for float, double, float16 and bfloat16.
//   functor      - forward functor template, instantiated as ops::functor<T>
//   grad_functor - gradient functor template, instantiated as
//                  ops::grad_functor<T>
//   op_name      - not referenced by the expansion; accepted so the macro
//                  fits the four-argument entries of
//                  FOR_EACH_ACTIVATION_CUDA_OP.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,      \
                                        grad_functor)                    \
  REGISTER_OP_CUDA_KERNEL(                                               \
      act_type,                                                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<float>>,                    \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<double>>,                   \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<plat::float16>>,            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<plat::bfloat16>>);          \
  REGISTER_OP_CUDA_KERNEL(                                               \
      act_type##_grad,                                                   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<float>>,           \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<double>>,          \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<plat::float16>>,   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<plat::bfloat16>>);

// Same as REGISTER_ACTIVATION_CUDA_KERNEL, but both `act_type` and
// `act_type##_grad` are also registered for int and int64_t. The full type
// list is float, double, int, int64_t, float16, bfloat16.
// `op_name` is not referenced by the expansion. No invocation of this macro
// appears in this file as shown.
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor,  \
                                            grad_functor)                \
  REGISTER_OP_CUDA_KERNEL(                                               \
      act_type,                                                          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<float>>,                    \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<double>>,                   \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<int>>,                      \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<int64_t>>,                  \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<plat::float16>>,            \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,                 \
                                ops::functor<plat::bfloat16>>);          \
  REGISTER_OP_CUDA_KERNEL(                                               \
      act_type##_grad,                                                   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<float>>,           \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<double>>,          \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<int>>,             \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<int64_t>>,         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<plat::float16>>,   \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,             \
                                    ops::grad_functor<plat::bfloat16>>);

/* ========================================================================== */

/* ======================== celu register  ============================ */
// celu and celu_grad go through REGISTER_ACTIVATION_CUDA_KERNEL. The
// double-grad op celu_grad_grad is registered separately, for float, double
// and float16 only.
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
                                CudaCELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<float>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<double>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
// Double-grad op sqrt_grad_grad: float, double, float16 and bfloat16.
REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::bfloat16>>);
/* ========================================================================== */

/* ===========================   rsqrt register  ============================ */
// Double-grad op rsqrt_grad_grad: float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================  square register  ============================ */
// Double-grad op square_grad_grad: float, double, float16, bfloat16, int and
// int64_t.
REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::bfloat16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ==========================   logit register  ============================ */
// No logit kernels are registered in this section (presumably they are
// registered elsewhere — TODO confirm).
/* ========================================================================== */

/* ==========================   exp register  ============================ */
// No CUDA registration in this section; in this file exp is registered only
// in the PADDLE_WITH_XPU_KP block.
/* ========================================================================== */

/* ==========================   expm1 register  ============================ */
// No expm1 kernels are registered in this section.
/* ========================================================================== */

// X-macro over the activations whose functors are defined at the top of this
// file. Each entry is (op type, op name, forward functor, gradient functor)
// and is handed to `register_macro`.
#define FOR_EACH_ACTIVATION_CUDA_OP(register_macro)                        \
  register_macro(soft_relu, SoftRelu, CudaSoftReluFunctor,                 \
                 CudaSoftReluGradFunctor);                                 \
  register_macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);    \
  register_macro(softsign, Softsign, CudaSoftsignFunctor,                  \
                 CudaSoftsignGradFunctor);

FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
#ifdef PADDLE_WITH_XPU_KP

// Registers `act_type` and `act_type##_grad` as float kernels for library KP
// on plat::XPUPlace. It reuses ActivationCudaKernel / ActivationGradCudaKernel
// with the XPU device context. `functor` and `grad_functor` are fully
// qualified functor templates.
#define REGISTER_ACTIVATION_XPU_KP_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(                                                      \
      act_type, KP, plat::XPUPlace,                                        \
      ops::ActivationCudaKernel<plat::XPUDeviceContext, functor<float>>);  \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, KP, plat::XPUPlace,                                 \
      ops::ActivationGradCudaKernel<plat::XPUDeviceContext,                \
                                    grad_functor<float>>)

REGISTER_ACTIVATION_XPU_KP_KERNEL(brelu, phi::funcs::CudaBReluFunctor,
                                  phi::funcs::CudaBReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(ceil, ops::CudaCeilFunctor,
                                  ops::CudaZeroGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(celu, ops::CudaCELUFunctor,
                                  ops::CudaCELUGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(elu, ops::CudaELUFunctor,
                                  ops::CudaELUGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(exp, ops::CudaExpFunctor,
                                  ops::CudaExpGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(floor, ops::CudaFloorFunctor,
                                  ops::CudaZeroGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(hard_shrink, ops::CudaHardShrinkFunctor,
                                  ops::CudaHardShrinkGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(hard_sigmoid, ops::CudaHardSigmoidFunctor,
                                  ops::CudaHardSigmoidGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(hard_swish, ops::CudaHardSwishFunctor,
                                  ops::CudaHardSwishGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(leaky_relu, phi::funcs::CudaLeakyReluFunctor,
                                  phi::funcs::CudaLeakyReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(log, ops::CudaLogFunctor,
                                  ops::CudaLogGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(log1p, ops::CudaLog1pFunctor,
                                  ops::CudaLog1pGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(logsigmoid, ops::CudaLogSigmoidFunctor,
                                  ops::CudaLogSigmoidGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(reciprocal, ops::CudaReciprocalFunctor,
                                  ops::CudaReciprocalGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(relu, phi::funcs::CudaReluFunctor,
                                  phi::funcs::CudaReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(relu6, ops::CudaRelu6Functor,
                                  ops::CudaRelu6GradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(sigmoid, ops::CudaSigmoidFunctor,
                                  ops::CudaSigmoidGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(silu, ops::CudaSiluFunctor,
                                  ops::CudaSiluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(soft_relu, ops::CudaSoftReluFunctor,
                                  ops::CudaSoftReluGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(softplus, ops::CudaSoftplusFunctor,
                                  ops::CudaSoftplusGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(softshrink, ops::CudaSoftShrinkFunctor,
                                  ops::CudaSoftShrinkGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(softsign, ops::CudaSoftsignFunctor,
                                  ops::CudaSoftsignGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(sqrt, ops::CudaSqrtFunctor,
                                  ops::CudaSqrtGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(square, ops::CudaSquareFunctor,
                                  ops::CudaSquareGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(swish, ops::CudaSwishFunctor,
                                  ops::CudaSwishGradFunctor);
REGISTER_ACTIVATION_XPU_KP_KERNEL(thresholded_relu,
                                  ops::CudaThresholdedReluFunctor,
                                  ops::CudaThresholdedReluGradFunctor);

#undef REGISTER_ACTIVATION_XPU_KP_KERNEL

#endif  // PADDLE_WITH_XPU_KP