/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
11

Y
Yi Wang 已提交
12
#include "paddle/fluid/operators/activation_op.h"
13 14
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/cuda_device_function.h"
K
Kexin Zhao 已提交
15
#include "paddle/fluid/platform/float16.h"
16

17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
namespace paddle {
namespace operators {

// Short aliases used by the functors and kernels in this file.
using float16 = paddle::platform::float16;
using Tensor = framework::Tensor;

/* CudaVecType<T> maps an element type T to the vector type used for the
 * vectorized loads/stores in the kernels below. `vecsize` is the number of
 * T elements packed into one `type` value:
 *   - default (e.g. double): no packing, type = T, vecsize = 1
 *   - platform::float16: two elements per __half2
 *   - float: four elements per float4
 * NOTE(review): kernels reinterpret_cast T* buffers to `type*`, so buffers
 * are assumed to be aligned to sizeof(type) (16 bytes for float4) — confirm.
 */
template <typename T>
struct CudaVecType {
  static constexpr int vecsize = 1;
  using type = T;
};

template <>
struct CudaVecType<platform::float16> {
  static constexpr int vecsize = 2;
  using type = __half2;
};

template <>
struct CudaVecType<float> {
  static constexpr int vecsize = 4;
  using type = float4;
};

// Common base class for the GPU activation functors in this file.
// ELEMENT_TYPE exposes the scalar element type to the op kernels.
// GetAttrs() returns (attribute name, destination) pairs that the op kernel
// fills from the op's float attributes. The base has no attributes.
template <typename T>
class BaseGPUFunctor {
 public:
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return {}; }
};

/* ========================================================================== */

/* ===========================    relu forward   ============================ */
template <typename T>
class ReluGPUFunctor : public BaseGPUFunctor<T> {
 private:
  T zero_;

 public:
  ReluGPUFunctor() : zero_(static_cast<T>(0.0f)) {}

  // Vectorized relu forward, out = max(in, 0). This generic version is for
  // element types whose vector type is T itself (e.g. double). float and
  // float16 have explicit specializations.
  __device__ __forceinline__ typename CudaVecType<T>::type Compute(
      const typename CudaVecType<T>::type in) {
    if (in > zero_) {
      return in;
    }
    return zero_;
  }

  // Scalar relu forward, used for the trailing num % vecsize elements.
  __device__ __forceinline__ T ComputeRemainder(const T in) {
    if (in > zero_) {
      return in;
    }
    return zero_;
  }
};

// relu forward for float, four lanes at a time: out = in > 0 ? in : 0.
// A per-lane select is used instead of `(in > 0) * in`. The multiply turns
// -inf into NaN (0 * -inf) and propagates NaN, whereas ComputeRemainder
// returns 0 for both. The select keeps the vector and tail paths consistent.
template <>
__device__ __forceinline__ CudaVecType<float>::type
ReluGPUFunctor<float>::Compute(const CudaVecType<float>::type in) {
  return make_float4(in.x > zero_ ? in.x : zero_, in.y > zero_ ? in.y : zero_,
                     in.z > zero_ ? in.z : zero_, in.w > zero_ ? in.w : zero_);
}

// relu forward for float16, two lanes at a time: out = in > 0 ? in : 0.
// Each lane is widened to float (exact for half) and selected, then rounded
// back (also exact, since the result is an input value or 0). This avoids
// the `(in > 0) * in` form, which yields NaN for -inf inputs (0 * -inf) and
// disagrees with ComputeRemainder. It also needs no architecture-specific
// preprocessor branch.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
ReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type in) {
  const float2 xx = __half22float2(in);
  return __floats2half2_rn(xx.x > 0.0f ? xx.x : 0.0f,
                           xx.y > 0.0f ? xx.y : 0.0f);
}
/* ========================================================================== */

/* ===========================    relu backward   ============================
 */

template <typename T>
class ReluGradGPUFunctor : public BaseGPUFunctor<T> {
 private:
  T zero_;

 public:
  ReluGradGPUFunctor() : zero_(static_cast<T>(0.0f)) {}

  // Vectorized relu backward, dx = out > 0 ? dout : 0. This generic version
  // is for element types whose vector type is T itself (e.g. double).
  __device__ __forceinline__ typename CudaVecType<T>::type Compute(
      const typename CudaVecType<T>::type out,
      const typename CudaVecType<T>::type dout) {
    if (out > zero_) {
      return dout;
    }
    return zero_;
  }

  // Scalar relu backward, used for the trailing num % vecsize elements.
  __device__ __forceinline__ T ComputeRemainder(const T out, const T dout) {
    if (out > zero_) {
      return dout;
    }
    return zero_;
  }

  // The gradient is computed from the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// relu backward for float, four lanes at a time: dx = out > 0 ? dout : 0.
// A per-lane select is used instead of `(out > 0) * dout`. With the
// multiply, an inf/NaN dout in a lane where out <= 0 produces NaN instead
// of 0, which disagrees with ComputeRemainder.
template <>
__device__ __forceinline__ CudaVecType<float>::type
ReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type out,
                                   const CudaVecType<float>::type dout) {
  return make_float4(out.x > zero_ ? dout.x : zero_,
                     out.y > zero_ ? dout.y : zero_,
                     out.z > zero_ ? dout.z : zero_,
                     out.w > zero_ ? dout.w : zero_);
}

// relu backward for float16, two lanes at a time:
// dx = out > 0 ? dout : 0.
// Lanes are widened to float (exact) and selected, then rounded back (exact,
// since the result is a dout value or 0). A masked multiply
// `(out > 0) * dout` would produce NaN for inf/NaN dout where out <= 0 and
// would disagree with ComputeRemainder.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
ReluGradGPUFunctor<float16>::Compute(const CudaVecType<float16>::type out,
                                     const CudaVecType<float16>::type dout) {
  const float2 xx = __half22float2(out);
  const float2 yy = __half22float2(dout);
  return __floats2half2_rn(xx.x > 0.0f ? yy.x : 0.0f,
                           xx.y > 0.0f ? yy.y : 0.0f);
}

151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
/* ========================================================================== */
/* ========================    leaky relu forward    ========================
 */
template <typename T>
class LeakyReluGPUFunctor : public BaseGPUFunctor<T> {
 private:
  T zero_;
  // Slope applied to non-positive inputs. The op kernel fills it from the
  // "alpha" attribute via GetAttrs() before launching. It is zero-initialized
  // so a functor used without GetAttrs() never reads an indeterminate value.
  float alpha_;

 public:
  LeakyReluGPUFunctor() : zero_(static_cast<T>(0.0f)), alpha_(0.0f) {}

  // Uses the AttrPair of the base class this functor actually derives from.
  typename BaseGPUFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha_}};
  }

  // Vectorized leaky relu forward for element types whose vector type is T
  // itself (e.g. double): out = x > 0 ? x : x * alpha.
  __device__ __forceinline__ typename CudaVecType<T>::type Compute(
      const typename CudaVecType<T>::type in) {
    return in > zero_ ? in : static_cast<T>(alpha_) * in;
  }

  // Scalar leaky relu forward, used for the trailing num % vecsize elements.
  __device__ __forceinline__ T ComputeRemainder(const T in) {
    return in > zero_ ? in : static_cast<T>(alpha_) * in;
  }
};

// leaky relu forward for float, four lanes at a time:
// out = x > 0 ? x : x * alpha (per lane).
template <>
__device__ __forceinline__ CudaVecType<float>::type
LeakyReluGPUFunctor<float>::Compute(const CudaVecType<float>::type in) {
  const float out_x = in.x > zero_ ? in.x : in.x * alpha_;
  const float out_y = in.y > zero_ ? in.y : in.y * alpha_;
  const float out_z = in.z > zero_ ? in.z : in.z * alpha_;
  const float out_w = in.w > zero_ ? in.w : in.w * alpha_;
  return make_float4(out_x, out_y, out_z, out_w);
}

// leaky relu forward for float16, two lanes at a time. Each lane is computed
// in float: out = x > 0 ? x : x * alpha.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
LeakyReluGPUFunctor<float16>::Compute(const CudaVecType<float16>::type in) {
  const float2 lanes = __half22float2(in);
  const float out_x = lanes.x > 0.0f ? lanes.x : lanes.x * alpha_;
  const float out_y = lanes.y > 0.0f ? lanes.y : lanes.y * alpha_;
  return __floats2half2_rn(out_x, out_y);
}
/* ========================================================================== */

/* ===========================  leaky relu backward   =======================
 */
template <typename T>
class LeakyReluGradGPUFunctor : public BaseGPUFunctor<T> {
 private:
  T zero_;
  // Slope applied where the forward input is non-positive. The op kernel
  // fills it from the "alpha" attribute via GetAttrs(). It is
  // zero-initialized so it is never read indeterminate.
  float alpha_;

 public:
  LeakyReluGradGPUFunctor() : zero_(static_cast<T>(0.0f)), alpha_(0.0f) {}

  // Uses the AttrPair of the base class this functor actually derives from.
  typename BaseGPUFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha_}};
  }

  // Vectorized leaky relu backward for element types whose vector type is T
  // itself (e.g. double): dx = x > 0 ? dout : alpha * dout.
  __device__ __forceinline__ typename CudaVecType<T>::type Compute(
      const typename CudaVecType<T>::type in,
      const typename CudaVecType<T>::type dout) {
    return in > zero_ ? dout : static_cast<T>(alpha_) * dout;
  }

  // Scalar leaky relu backward, used for the trailing num % vecsize
  // elements.
  __device__ __forceinline__ T ComputeRemainder(const T in, const T dout) {
    return in > zero_ ? dout : static_cast<T>(alpha_) * dout;
  }

  // The gradient is computed from the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// leaky relu backward for float, four lanes at a time:
// dx = x > 0 ? dout : alpha * dout (per lane).
template <>
__device__ __forceinline__ CudaVecType<float>::type
LeakyReluGradGPUFunctor<float>::Compute(const CudaVecType<float>::type in,
                                        const CudaVecType<float>::type dout) {
  const float dx_x = in.x > zero_ ? dout.x : alpha_ * dout.x;
  const float dx_y = in.y > zero_ ? dout.y : alpha_ * dout.y;
  const float dx_z = in.z > zero_ ? dout.z : alpha_ * dout.z;
  const float dx_w = in.w > zero_ ? dout.w : alpha_ * dout.w;
  return make_float4(dx_x, dx_y, dx_z, dx_w);
}

// leaky relu backward for float16, two lanes at a time. Each lane is
// computed in float: dx = x > 0 ? dout : alpha * dout.
template <>
__device__ __forceinline__ CudaVecType<float16>::type
LeakyReluGradGPUFunctor<float16>::Compute(
    const CudaVecType<float16>::type in,
    const CudaVecType<float16>::type dout) {
  const float2 x_lanes = __half22float2(in);
  const float2 g_lanes = __half22float2(dout);
  const float dx_x = x_lanes.x > 0.0f ? g_lanes.x : alpha_ * g_lanes.x;
  const float dx_y = x_lanes.y > 0.0f ? g_lanes.y : alpha_ * g_lanes.y;
  return __floats2half2_rn(dx_x, dx_y);
}

252 253 254 255 256 257 258 259 260 261 262 263 264 265
/* ========================================================================== */

/* Elementwise activation backward kernel: computes dx[i] from
 * forward_data[i] and dout[i] for i in [0, num).
 *
 * The first (num / vecsize) * vecsize elements are processed as
 * CudaVecType<T>::type vectors with functor.Compute. The remaining
 * num % vecsize elements are processed one at a time with
 * functor.ComputeRemainder. Both phases are grid-stride loops
 * (stride = blockDim.x * gridDim.x), so every element is covered for any
 * 1-D launch configuration.
 *
 * NOTE(review): the buffers are reinterpret_cast to the vector type, so they
 * are assumed aligned to sizeof(CudaVecType<T>::type) — confirm for tensors
 * that may start at an offset into their allocation.
 */
template <typename T, typename Functor>
__global__ void ActivationGradKernelVec(const T* forward_data, const T* dout,
                                        T* dx, int num, Functor functor) {
  using VecType = typename CudaVecType<T>::type;
  constexpr int vecsize = CudaVecType<T>::vecsize;
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  const int stride = blockDim.x * gridDim.x;
  const int loop = num / vecsize;
  const VecType* in_forward = reinterpret_cast<const VecType*>(forward_data);
  const VecType* in_dout = reinterpret_cast<const VecType*>(dout);
  VecType* out = reinterpret_cast<VecType*>(dx);

  for (int i = idx; i < loop; i += stride) {
// `#ifdef A || B` only tests A; `#if defined(...)` lets CUDA sm_35+ use
// read-only cache loads too, not just HIP.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
    const VecType forward_vec = __ldg(in_forward + i);
    const VecType dout_vec = __ldg(in_dout + i);
#else
    const VecType forward_vec = in_forward[i];
    const VecType dout_vec = in_dout[i];
#endif
    out[i] = functor.Compute(forward_vec, dout_vec);
  }

  // Scalar tail (fewer than vecsize elements). It is spread across threads
  // by index rather than assigned to the single thread with idx == loop,
  // which may not exist (e.g. when loop is a multiple of the launch size).
  for (int i = loop * vecsize + idx; i < num; i += stride) {
    dx[i] = functor.ComputeRemainder(forward_data[i], dout[i]);
  }
}

/* Elementwise activation forward kernel: dst[i] = f(src[i]) for i in
 * [0, num).
 *
 * The first (num / vecsize) * vecsize elements are processed as
 * CudaVecType<T>::type vectors with functor.Compute. The remaining
 * num % vecsize elements are processed one at a time with
 * functor.ComputeRemainder. Both phases are grid-stride loops
 * (stride = blockDim.x * gridDim.x), so every element is covered for any
 * 1-D launch configuration.
 *
 * NOTE(review): src and dst are reinterpret_cast to the vector type, so they
 * are assumed aligned to sizeof(CudaVecType<T>::type) — confirm for tensors
 * that may start at an offset into their allocation.
 */
template <typename T, typename Functor>
__global__ void ActivationkernelVec(const T* src, T* dst, int num,
                                    Functor functor) {
  constexpr int vecsize = CudaVecType<T>::vecsize;
  using VecType = typename CudaVecType<T>::type;
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  const int stride = blockDim.x * gridDim.x;
  const int loop = num / vecsize;
  const VecType* in = reinterpret_cast<const VecType*>(src);
  VecType* out = reinterpret_cast<VecType*>(dst);

  for (int i = idx; i < loop; i += stride) {
// `#ifdef A || B` only tests A; `#if defined(...)` lets CUDA sm_35+ use
// read-only cache loads too, not just HIP.
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350)
    const VecType x_vec = __ldg(in + i);
#else
    const VecType x_vec = in[i];
#endif
    out[i] = functor.Compute(x_vec);
  }

  // Scalar tail (fewer than vecsize elements). It is spread across threads
  // by index rather than assigned to the single thread with idx == loop,
  // which may not exist (e.g. when loop is a multiple of the launch size).
  for (int i = loop * vecsize + idx; i < num; i += stride) {
    dst[i] = functor.ComputeRemainder(src[i]);
  }
}

/* Forward op kernel for the vectorized GPU activation functors in this file.
 * It extracts X/Out, copies the op's float attributes into the functor
 * through GetAttrs(), and launches ActivationkernelVec on the CUDA device
 * context's stream.
 */
template <typename DeviceContext, typename Functor>
class ActivationGPUKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* in_x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(context, &in_x, &out);
    auto& dev_ctx = context.template device_context<DeviceContext>();

    // NOTE(review): the element count is held in an int (matching the
    // kernel's `num` parameter), so this path assumes it fits in int —
    // confirm for very large tensors.
    int num = in_x->numel();
    const T* input_data = in_x->data<T>();
    T* output_data = out->mutable_data<T>(dev_ctx.GetPlace(),
                                          static_cast<size_t>(num * sizeof(T)));

    int block = 512;
#ifdef __HIPCC__
    block = 256;
#endif
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    constexpr int vecsize = CudaVecType<T>::vecsize;
    // One work item per full vector plus one for a non-empty scalar tail.
    // This guarantees the launch contains thread index num / vecsize, so the
    // tail is covered even by a kernel that assigns it to that thread alone.
    int work = num / vecsize + (num % vecsize != 0 ? 1 : 0);
    int grid = max((work + block - 1) / block, 1);
    auto stream = context.cuda_device_context().stream();
    ActivationkernelVec<T, Functor><<<grid, block, 0, stream>>>(
        input_data, output_data, num, functor);
  }
};

/* Backward op kernel for the vectorized GPU activation functors in this file.
 * It selects the forward tensor the functor depends on (Out for kDepOut,
 * X for kDepX; otherwise dOut is passed as a placeholder). It then copies
 * the op's float attributes into the functor and launches
 * ActivationGradKernelVec on the CUDA device context's stream.
 */
template <typename DeviceContext, typename Functor>
class ActivationGradGPUKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor *x, *out, *d_out;
    framework::Tensor* d_x = nullptr;
    x = out = d_out = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &x, &out, &d_out,
                                                    &d_x);
    // NOTE(review): the element count is held in an int (matching the
    // kernel's `num` parameter), so this path assumes it fits in int —
    // confirm for very large tensors.
    int numel = d_out->numel();
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* dx_data = d_x->mutable_data<T>(
        dev_ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
    auto* dout_data = d_out->data<T>();

    // Placeholder used when the functor needs neither X nor Out.
    auto* forward_data = dout_data;
    if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
      // Only need forward output Out
      forward_data = out->data<T>();
    } else if (static_cast<int>(Functor::FwdDeps()) ==
               static_cast<int>(kDepX)) {
      // Only need forward input X
      forward_data = x->data<T>();
    }

    int block = 512;
#ifdef __HIPCC__
    block = 256;
#endif
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    constexpr int vecsize = CudaVecType<T>::vecsize;
    // One work item per full vector plus one for a non-empty scalar tail.
    // This guarantees the launch contains thread index numel / vecsize, so
    // the tail is covered even by a kernel that assigns it to that thread
    // alone.
    int work = numel / vecsize + (numel % vecsize != 0 ? 1 : 0);
    int grid = max((work + block - 1) / block, 1);
    auto stream = context.cuda_device_context().stream();
    ActivationGradKernelVec<T, Functor><<<grid, block, 0, stream>>>(
        forward_data, dout_data, dx_data, numel, functor);
  }
};

}  // namespace operators
}  // namespace paddle

395
// Short namespace aliases used by the kernel registrations below.
namespace plat = paddle::platform;
namespace ops = paddle::operators;

398 399
// Registers `act_type` with the generic ActivationKernel and `act_type_grad`
// with ActivationGradKernel, each for float, double and float16 on
// CUDADeviceContext. `op_name` is part of the FOR_EACH_ACTIVATION_OP
// callback signature and is unused here.
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,         \
                                        grad_functor)                       \
  REGISTER_OP_CUDA_KERNEL(                                                  \
      act_type,                                                             \
      ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>,  \
      ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
      ops::ActivationKernel<plat::CUDADeviceContext,                        \
                            ops::functor<plat::float16>>);                  \
  REGISTER_OP_CUDA_KERNEL(                                                  \
      act_type##_grad,                                                      \
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<float>>,                  \
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<double>>,                 \
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<plat::float16>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
414

415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431
// Registers `act_type` with ActivationGPUKernel and `act_type_grad` with
// ActivationGradGPUKernel (the vectorized kernels defined in this file),
// each for float, double and float16 on CUDADeviceContext. `op_name` is
// unused here.
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor,            \
                                       grad_functor)                          \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      act_type,                                                               \
      ops::ActivationGPUKernel<plat::CUDADeviceContext,                       \
                               ops::functor<float>>,                          \
      ops::ActivationGPUKernel<plat::CUDADeviceContext,                       \
                               ops::functor<double>>,                         \
      ops::ActivationGPUKernel<plat::CUDADeviceContext,                       \
                               ops::functor<plat::float16>>);                 \
  REGISTER_OP_CUDA_KERNEL(                                                    \
      act_type##_grad,                                                        \
      ops::ActivationGradGPUKernel<plat::CUDADeviceContext,                   \
                                   ops::grad_functor<float>>,                 \
      ops::ActivationGradGPUKernel<plat::CUDADeviceContext,                   \
                                   ops::grad_functor<double>>,                \
      ops::ActivationGradGPUKernel<plat::CUDADeviceContext,                   \
                                   ops::grad_functor<plat::float16>>);

432
/* ======================== leaky relu register  ============================ */
433 434
// leaky_relu and leaky_relu_grad use the vectorized GPU kernels.
// leaky_relu_grad_grad uses ActivationDoubleGradKernel.
REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluGPUFunctor,
                               LeakyReluGradGPUFunctor);

REGISTER_OP_CUDA_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
444
/* ========================================================================== */
445

D
Double_V 已提交
446 447 448 449 450 451 452 453 454 455 456 457
/* ======================== elu register  ============================ */
// elu / elu_grad use the generic kernels; elu_grad_grad uses
// ELUDoubleGradKernel.
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

458
/* ===========================    relu register  ============================ */
459
// relu and relu_grad use the vectorized GPU kernels. relu_grad_grad uses
// ActivationDoubleGradKernel.
REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor);

REGISTER_OP_CUDA_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
469 470
/* ========================================================================== */

471 472 473 474 475 476 477 478 479 480 481 482 483
/* ===========================    tanh register  ============================ */
// tanh and tanh_grad use the generic kernels. tanh_grad_grad uses
// TanhDoubleGradKernel.
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, TanhFunctor, TanhGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    tanh_grad_grad,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<float>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);
/* ========================================================================== */

L
lvmengsi 已提交
484 485 486 487 488 489 490 491 492 493 494 495 496
/* ===========================   sqrt register  ============================= */
// sqrt and sqrt_grad use the generic kernels. sqrt_grad_grad uses
// SqrtDoubleGradKernel.
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

W
whs 已提交
497 498 499 500 501 502 503 504 505 506 507 508 509 510
/* ===========================   rsqrt register  =============================
 */
// rsqrt and rsqrt_grad use the generic kernels. rsqrt_grad_grad uses
// RsqrtDoubleGradKernel.
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<plat::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

511
/* ===========================  square register  ============================ */
512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
// square is registered explicitly (not via the macros) because it also
// supports int and int64_t in addition to float, double and float16.
REGISTER_OP_CUDA_KERNEL(
    square,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::SquareFunctor<float>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::SquareFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::SquareFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::SquareFunctor<int64_t>>,
    ops::ActivationKernel<plat::CUDADeviceContext,
                          ops::SquareFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    square_grad,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::SquareGradFunctor<float>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::SquareGradFunctor<double>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::SquareGradFunctor<int>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::SquareGradFunctor<int64_t>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::SquareGradFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
544
/* ========================================================================== */
545 546 547 548 549 550

/* ==========================   pow register  ============================ */

// pow and pow_grad use PowKernel / PowGradKernel for float, double, int,
// int64_t and float16.
REGISTER_OP_CUDA_KERNEL(
    pow,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
    ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
    ops::PowGradKernel<plat::CUDADeviceContext,
                       ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585

/* ==========================   exp register  ============================ */

// exp and exp_grad use the generic kernels for float, double, int, int64_t
// and float16.
REGISTER_OP_CUDA_KERNEL(
    exp,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<float>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<double>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
    ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
    ops::ActivationKernel<plat::CUDADeviceContext,
                          ops::ExpFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    exp_grad,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::ExpGradFunctor<float>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::ExpGradFunctor<double>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::ExpGradFunctor<int>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::ExpGradFunctor<int64_t>>,
    ops::ActivationGradKernel<plat::CUDADeviceContext,
                              ops::ExpGradFunctor<plat::float16>>);
/* ========================================================================== */

586 587 588 589 590 591 592 593 594 595 596
/* ==========================  Log register ==================================*/
// log and log_grad use the generic kernels. log_grad_grad uses
// LogDoubleGradKernel.
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor);

REGISTER_OP_CUDA_KERNEL(
    log_grad_grad,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CUDADeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */