/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"

namespace paddle {
namespace operators {

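// soft_relu and softsign functors are defined here rather than pulled in from
// phi below. CudaSoftReluFunctor evaluates in MPType, the wider type selected
// by details::MPTypeTrait (float when T is float16/bfloat16), and casts the
// result back to T.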
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_x) const {
    MPType x = static_cast<MPType>(arg_x);
    MPType t = static_cast<MPType>(threshold);
    MPType temp_min = x < t ? x : t;
    MPType temp_max = temp_min > -t ? temp_min : -t;
    return static_cast<T>(log(one + exp(temp_max)));
  }
};

template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
  // threshold should not be negative
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType out = static_cast<MPType>(arg_out);
    MPType t = static_cast<MPType>(threshold);
    return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
                                 : static_cast<T>(0.0f);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + abs(x))
  __device__ __forceinline__ T operator()(const T x) const {
    return x / (one + abs(x));
  }
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + abs(x))^2
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    T temp = one + abs(x);
    return dout / (temp * temp);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

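// Generic forward/backward kernels. The functor supplies the element type, the
// attribute list (GetAttrs) and the per-element __device__ operator(); the
// kernels copy the op's attributes into the functor and launch a same-dims
// elementwise CUDA kernel over the input/output tensors.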
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(ctx, &x, &out);
    out->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    std::vector<const framework::Tensor*> ins = {x};
    std::vector<framework::Tensor*> outs = {out};
    auto functor = Functor();
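    // Copy the op's float attributes (e.g. "threshold") into the slots the
    // functor registered through GetAttrs().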
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        dev_ctx, ins, &outs, functor);
  }
};

template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor *x, *out, *d_out;
    framework::Tensor* d_x = nullptr;
    x = out = d_out = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(
        ctx, &x, &out, &d_out, &d_x);
    d_x->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto functor = Functor();
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }

    std::vector<const framework::Tensor*> ins = {d_out};
    std::vector<framework::Tensor*> outs = {d_x};

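    // Functor::FwdDeps() declares which forward tensor the backward functor
    // reads (Out, X, or neither); append it to the elementwise inputs.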
    if (static_cast<int>(Functor::FwdDeps()) ==
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      // Only need forward output Out
      ins.push_back(out);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, functor);
    } else if (static_cast<int>(Functor::FwdDeps()) ==
               static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      // Only need forward input X
      ins.push_back(x);
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, functor);
    } else {
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, ins, &outs, functor);
    }
  }
};

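// These activations are implemented in phi; USE_PHI_FUNCTOR (see
// activation_op.h) aliases the corresponding phi::funcs functors under the
// ops:: names used by the registrations below, analogous to the explicit
// `using` aliases for Round/Floor/Ceil/ZeroGrad further down.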
USE_PHI_FUNCTOR(CudaCos)
USE_PHI_FUNCTOR(CudaTan)
USE_PHI_FUNCTOR(CudaAcos)
USE_PHI_FUNCTOR(CudaSin)
USE_PHI_FUNCTOR(CudaAsin)
USE_PHI_FUNCTOR(CudaAtan)
USE_PHI_FUNCTOR(CudaSinh)
USE_PHI_FUNCTOR(CudaCosh)
USE_PHI_FUNCTOR(CudaAsinh)
USE_PHI_FUNCTOR(CudaAcosh)
USE_PHI_FUNCTOR(CudaAtanh)
USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu)
USE_PHI_FUNCTOR(CudaRelu6)
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
USE_PHI_FUNCTOR(CudaSwish)
USE_PHI_FUNCTOR(CudaHardSwish)

template <typename T>
using CudaRoundFunctor = phi::funcs::CudaRoundFunctor<T>;

template <typename T>
using CudaFloorFunctor = phi::funcs::CudaFloorFunctor<T>;

template <typename T>
using CudaCeilFunctor = phi::funcs::CudaCeilFunctor<T>;

template <typename T>
using CudaZeroGradFunctor = phi::funcs::CudaZeroGradFunctor<T>;

USE_PHI_FUNCTOR(CudaExp)
USE_PHI_FUNCTOR(CudaExpm1)
USE_PHI_FUNCTOR(CudaMish)
USE_PHI_FUNCTOR(CudaSTanh)
USE_PHI_FUNCTOR(CudaReciprocal)
USE_PHI_FUNCTOR(CudaSquare)
USE_PHI_FUNCTOR(CudaSqrt)
USE_PHI_FUNCTOR(CudaRsqrt)
USE_PHI_FUNCTOR(CudaSoftplus)

template <typename T>
using CudaELUGradNegativeAlphaFunctor =
    phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

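// Registers the forward kernel for float, double, float16 and bfloat16 and the
// matching act_type##_grad kernel for the same types; the op_name argument is
// not used by this macro.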
#define REGISTER_ACTIVATION_CUDA_KERNEL(                               \
    act_type, op_name, functor, grad_functor)                          \
  REGISTER_OP_CUDA_KERNEL(                                             \
      act_type,                                                        \
      ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,   \
                                ops::functor<float>>,                  \
      ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,   \
                                ops::functor<double>>,                 \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,               \
                                ops::functor<plat::float16>>,          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,               \
                                ops::functor<plat::bfloat16>>);        \
  REGISTER_OP_CUDA_KERNEL(                                             \
      act_type##_grad,                                                 \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<float>>,         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<double>>,        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<plat::float16>>, \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<plat::bfloat16>>);

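// Same as REGISTER_ACTIVATION_CUDA_KERNEL, but additionally registers int and
// int64_t kernels.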
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(                           \
    act_type, op_name, functor, grad_functor)                          \
  REGISTER_OP_CUDA_KERNEL(                                             \
      act_type,                                                        \
      ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,   \
                                ops::functor<float>>,                  \
      ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,   \
                                ops::functor<double>>,                 \
      ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,   \
                                ops::functor<int>>,                    \
      ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,   \
                                ops::functor<int64_t>>,                \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,               \
                                ops::functor<plat::float16>>,          \
      ops::ActivationCudaKernel<plat::CUDADeviceContext,               \
                                ops::functor<plat::bfloat16>>);        \
  REGISTER_OP_CUDA_KERNEL(                                             \
      act_type##_grad,                                                 \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<float>>,         \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<double>>,        \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<int>>,           \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<int64_t>>,       \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<plat::float16>>, \
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,           \
                                    ops::grad_functor<plat::bfloat16>>);

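// relu6 is registered directly (for float, double, int, int64_t, float16 and
// bfloat16) using the phi-backed CudaRelu6Functor/CudaRelu6GradFunctor.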
REGISTER_OP_CUDA_KERNEL(
    relu6,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<float>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<double>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<int>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaRelu6Functor<plat::float16>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaRelu6Functor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu6_grad,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<plat::float16>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<plat::bfloat16>>);

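// Activations that still use the operator-side functors defined at the top of
// this file.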
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                  \
  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);

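// Expands to REGISTER_ACTIVATION_CUDA_KERNEL(soft_relu, SoftRelu,
// CudaSoftReluFunctor, CudaSoftReluGradFunctor) and the softsign equivalent.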
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)

#ifdef PADDLE_WITH_XPU_KP
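// When this file is built as a Kernel Primitive (KP) source for XPU, the same
// activation functors are registered as KP kernels on plat::XPUPlace (float
// only), reusing ActivationCudaKernel/ActivationGradCudaKernel with
// paddle::platform::XPUDeviceContext.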
REGISTER_OP_KERNEL(
    brelu,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaBReluFunctor<float>>);
REGISTER_OP_KERNEL(
    brelu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaBReluGradFunctor<float>>);

REGISTER_OP_KERNEL(ceil,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaCeilFunctor<float>>);
REGISTER_OP_KERNEL(
    ceil_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaZeroGradFunctor<float>>);

REGISTER_OP_KERNEL(
    celu,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaCELUFunctor<float>>);
REGISTER_OP_KERNEL(
    celu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaCELUGradFunctor<float>>);

REGISTER_OP_KERNEL(elu,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaELUFunctor<float>>);
REGISTER_OP_KERNEL(
    elu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaELUGradFunctor<float>>);

REGISTER_OP_KERNEL(exp,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaExpFunctor<float>>);
REGISTER_OP_KERNEL(
    exp_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaExpGradFunctor<float>>);

REGISTER_OP_KERNEL(floor,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaFloorFunctor<float>>);
REGISTER_OP_KERNEL(
    floor_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaZeroGradFunctor<float>>);

REGISTER_OP_KERNEL(
    hard_shrink,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaHardShrinkFunctor<float>>);
REGISTER_OP_KERNEL(
    hard_shrink_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaHardShrinkGradFunctor<float>>);

REGISTER_OP_KERNEL(
    hard_sigmoid,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaHardSigmoidFunctor<float>>);
REGISTER_OP_KERNEL(
    hard_sigmoid_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaHardSigmoidGradFunctor<float>>);

REGISTER_OP_KERNEL(hard_swish,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaHardSwishFunctor<float>>);
REGISTER_OP_KERNEL(
    hard_swish_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaHardSwishGradFunctor<float>>);

REGISTER_OP_KERNEL(
    leaky_relu,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaLeakyReluFunctor<float>>);
REGISTER_OP_KERNEL(
    leaky_relu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaLeakyReluGradFunctor<float>>);

REGISTER_OP_KERNEL(log,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaLogFunctor<float>>);
REGISTER_OP_KERNEL(
    log_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaLogGradFunctor<float>>);

REGISTER_OP_KERNEL(log1p,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaLog1pFunctor<float>>);
REGISTER_OP_KERNEL(
    log1p_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaLog1pGradFunctor<float>>);

REGISTER_OP_KERNEL(
    logsigmoid,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaLogSigmoidFunctor<float>>);
REGISTER_OP_KERNEL(
    logsigmoid_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaLogSigmoidGradFunctor<float>>);

REGISTER_OP_KERNEL(
    reciprocal,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaReciprocalFunctor<float>>);
REGISTER_OP_KERNEL(
    reciprocal_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaReciprocalGradFunctor<float>>);

REGISTER_OP_KERNEL(
    relu,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaReluFunctor<float>>);
REGISTER_OP_KERNEL(
    relu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  phi::funcs::CudaReluGradFunctor<float>>);

REGISTER_OP_KERNEL(relu6,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaRelu6Functor<float>>);
REGISTER_OP_KERNEL(
    relu6_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaRelu6GradFunctor<float>>);

REGISTER_OP_KERNEL(sigmoid,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSigmoidFunctor<float>>);
REGISTER_OP_KERNEL(
    sigmoid_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSigmoidGradFunctor<float>>);

REGISTER_OP_KERNEL(silu,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSiluFunctor<float>>);
REGISTER_OP_KERNEL(
    silu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSiluGradFunctor<float>>);

REGISTER_OP_KERNEL(soft_relu,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSoftReluFunctor<float>>);
REGISTER_OP_KERNEL(
    soft_relu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftReluGradFunctor<float>>);

REGISTER_OP_KERNEL(softplus,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSoftplusFunctor<float>>);
REGISTER_OP_KERNEL(
    softplus_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftplusGradFunctor<float>>);

REGISTER_OP_KERNEL(
    softshrink,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaSoftShrinkFunctor<float>>);
REGISTER_OP_KERNEL(
    softshrink_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftShrinkGradFunctor<float>>);

REGISTER_OP_KERNEL(softsign,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSoftsignFunctor<float>>);
REGISTER_OP_KERNEL(
    softsign_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSoftsignGradFunctor<float>>);

REGISTER_OP_KERNEL(sqrt,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSqrtFunctor<float>>);
REGISTER_OP_KERNEL(
    sqrt_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSqrtGradFunctor<float>>);

REGISTER_OP_KERNEL(square,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSquareFunctor<float>>);
REGISTER_OP_KERNEL(
    square_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSquareGradFunctor<float>>);

REGISTER_OP_KERNEL(swish,
                   KP,
                   plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaSwishFunctor<float>>);
REGISTER_OP_KERNEL(
    swish_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaSwishGradFunctor<float>>);

REGISTER_OP_KERNEL(
    thresholded_relu,
    KP,
    plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              ops::CudaThresholdedReluFunctor<float>>);
REGISTER_OP_KERNEL(
    thresholded_relu_grad,
    KP,
    plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaThresholdedReluGradFunctor<float>>);

#endif  // PADDLE_WITH_XPU_KP