activation_functor.h 131.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16

17
#include <glog/logging.h>
18

19
#include <algorithm>
20
#include <cmath>
21 22 23 24 25 26 27 28 29 30 31 32
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <type_traits>

#include "paddle/phi/common/amp_type_traits.h"
Y
YuanRisheng 已提交
33
#include "paddle/phi/common/bfloat16.h"
34 35 36 37
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
Y
YuanRisheng 已提交
38
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
39

40 41 42 43
#ifdef PADDLE_WITH_XPU_KP
#define __forceinline__ __inline__
#endif

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
namespace phi {
namespace funcs {
// Which forward-pass tensor an activation backward functor reads; returned
// by each functor's static FwdDeps().
enum ActBwdOpFwdDeps {
  // Do not need any forward input/output
  kNoDeps = 0,
  // Only need forward input X
  kDepX = 1,
  // Only need forward output Out
  kDepOut = 2,
};

// Common base of the activation functors in this file.
// Exposes the element type and a default (empty) attribute list; functors
// with float attributes declare their own GetAttrs() that returns
// (name, pointer-to-member-storage) pairs.
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  // List of (attribute name, address of the float that stores it).
  using AttrPair = std::vector<std::pair<const char*, float*>>;

  // Default: the functor has no attributes.
  AttrPair GetAttrs() {
    AttrPair no_attrs;
    return no_attrs;
  }
};

// Scalar sine functor, used element-wise through unaryExpr().
template <typename T>
struct Sine {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = sin(val);
    return result;
  }
};

// float16: widen to float, take the sine, narrow back.
template <>
struct Sine<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(sin(widened));
  }
};

// bfloat16: widen to float, take the sine, narrow back.
template <>
struct Sine<dtype::bfloat16> {
  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::bfloat16(sin(widened));
  }
};

80 81 82 83 84 85 86 87 88 89 90 91
// Scalar cosine functor, used element-wise through unaryExpr().
template <typename T>
struct Cosine {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = cos(val);
    return result;
  }
};

// float16: widen to float, take the cosine, narrow back.
template <>
struct Cosine<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(cos(widened));
  }
};

// bfloat16: widen to float, take the cosine, narrow back.
template <>
struct Cosine<dtype::bfloat16> {
  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::bfloat16(cos(widened));
  }
};

99 100 101 102 103 104 105 106
// sine'(x) = cos(x)
// Backward of sin: dx = dout * cos(x). Reads the forward input x only.
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto cos_x = x.unaryExpr(Cosine<T>());
    dx.device(d) = dout * cos_x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// sine(x) = sin(x)
// Forward of sin: out = sin(x), element-wise.
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // Note(GGBond8488): Since Eigen3.3, Behavior like {A = (B * A).cwiseAbs()}
    // will give wrong result, details see
    // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
    // Hence the explicit eval() before the assignment.
    auto sin_x = x.unaryExpr(Sine<T>());
    out.device(d) = sin_x.eval();
  }
};

// sine''(x) = -sin(x)
//
// Second-order backward of y = sin(x):
//   d2x   = -sin(x) * d1y * d2d1x   (written to dX)
//   d2d1y =  cos(x) * d2d1x         (written to ddOut)
template <typename T>
struct SinDoubleGradFunctor : public BaseActivationFunctor<T> {
  // X and ddX are always read. dOut may be null, in which case the d1y
  // factor is replaced by zero. dX and ddOut are optional outputs: each is
  // written only when its pointer is non-null.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto d2d1x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinDoubleGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "x", "SinDoubleGrad"));

    // calculate d2x first, so d2d1y can inplace d2d1x
    if (dX) {
      // Flatten dX only after the null check. Fetching it before the guard
      // (as was done previously) made the guard pointless:
      // GET_DATA_SAFELY presumably rejects a null pointer -- TODO confirm
      // against its definition.
      auto d2x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "d2x", "SinDoubleGrad"));
      if (dOut) {
        auto d1y = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad"));
        d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * d1y;
      } else {
        d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * static_cast<T>(0);
      }
    }

    // calculate d2d1y
    if (ddOut) {
      auto d2d1y = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad"));
      d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine<T>());
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// 1st reverse grad
// y = sin(x)
// x --> y
// d1x = d1y * cos(x)
//
// 2nd reverse grad
// x, d1y --> d1x
// d2x = -sin(x) * d1y * d2d1x
// d2d1y = cos(x) * d2d1x
//
// 3rd reverse grad
// x, d1y, d2d1x --> d2x, d2d1y
// d3x = -cos(x) * d1y * d2d1x * d3d2x - sin(x) * d2d1x * d3d2d1y
// d3d1y = -sin(x) * d2d1x * d3d2x
// d3d2d1x = -sin(x) * d1y * d3d2x + cos(x) * d3d2d1y
// Third-order backward of y = sin(x).
//   d3x     = -cos(x) * d1y * d2d1x * d3d2x - sin(x) * d2d1x * d3d2d1y
//   d3d1y   = -sin(x) * d2d1x * d3d2x
//   d3d2d1x = -sin(x) * d1y * d3d2x + cos(x) * d3d2d1y
// Any term whose optional input (dOut, ddX, d_DDOut) is null is dropped;
// when no term remains the output is filled with zeros. Each output is
// written only when its pointer is non-null.
template <typename T>
struct SinTripleGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  const DenseTensor* d_DDOut,
                  const DenseTensor* d_dx_New,
                  DenseTensor* d_d_Out,
                  DenseTensor* d_x_New,
                  DenseTensor* d_DDx) const {
    auto* d = dev.eigen_device();
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "x", "SinTripleGrad"));
    auto d3d2x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "SinTripleGrad"));

    // Presence of the optional inputs.
    const bool has_d1y = (dOut != nullptr);
    const bool has_d2d1x = (ddX != nullptr);
    const bool has_d3d2d1y = (d_DDOut != nullptr);

    // ---- d3x ----
    if (d_x_New) {
      auto d3x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_x_New, "Output", "d3x", "SinTripleGrad"));
      if (!has_d2d1x || (!has_d1y && !has_d3d2d1y)) {
        // Both terms contain d2d1x; each also needs d1y or d3d2d1y.
        d3x.device(*d) = x * static_cast<T>(0);
      } else {
        auto d2d1x = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
        if (has_d1y && has_d3d2d1y) {
          auto d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
          auto d3d2d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
          d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d2d1x * d3d2x -
                           x.unaryExpr(Sine<T>()) * d2d1x * d3d2d1y;
        } else if (has_d1y) {
          auto d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
          d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d2d1x * d3d2x;
        } else {
          auto d3d2d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
          d3x.device(*d) = -x.unaryExpr(Sine<T>()) * d2d1x * d3d2d1y;
        }
      }
    }

    // ---- d3d1y ----
    if (d_d_Out) {
      auto d3d1y = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "SinTripleGrad"));
      if (!has_d2d1x) {
        d3d1y.device(*d) = static_cast<T>(0) * x;
      } else {
        auto d2d1x = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinTripleGrad"));
        d3d1y.device(*d) = -x.unaryExpr(Sine<T>()) * d2d1x * d3d2x;
      }
    }

    // ---- d3d2d1x ----
    if (d_DDx) {
      auto d3d2d1x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "SinTripleGrad"));
      if (has_d1y) {
        auto d1y = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(dOut, "Input", "d1y", "SinTripleGrad"));
        if (has_d3d2d1y) {
          auto d3d2d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
          d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d1y * d3d2x +
                               x.unaryExpr(Cosine<T>()) * d3d2d1y;
        } else {
          d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d1y * d3d2x;
        }
      } else if (has_d3d2d1y) {
        auto d3d2d1y = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "SinTripleGrad"));
        d3d2d1x.device(*d) = x.unaryExpr(Cosine<T>()) * d3d2d1y;
      } else {
        d3d2d1x.device(*d) = x * static_cast<T>(0);
      }
    }
  }

  // NOTE(review): reports kDepOut although operator() only reads X --
  // confirm whether this value is intentional.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
// reciprocal(x) = 1 / x
// Forward: element-wise reciprocal of x.
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T one = static_cast<T>(1);
    out.device(d) = one / x;
  }
};

// Backward of reciprocal, written in terms of the forward output:
// dx = dout * (-1) * out * out.
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    const T minus_one = static_cast<T>(-1);
    dx.device(d) = dout * minus_one * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
// 1st reverse grad
// y = cos(x)
// x --> y
// d1x = d1y * -sin(x)
//
// 2nd reverse grad
// x, d1y --> d1x
// d2x = -cos(x) * d1y * d2d1x
// d2d1y = -sin(x) * d2d1x
//
// 3rd reverse grad
// x, d1y, d2d1x --> d2x, d2d1y
// d3x = sin(x) * d1y * d2d1x * d3d2x - cos(x) * d2d1x * d3d2d1y
// d3d1y = -cos(x) * d2d1x * d3d2x
// d3d2d1x = -cos(x) * d1y * d3d2x - sin(x) * d3d2d1y

308 309 310 311 312 313 314 315
// cosine'(x) = -sin(x)
// Backward of cos: dx = -dout * sin(x). Reads the forward input x only.
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto sin_x = x.unaryExpr(Sine<T>());
    dx.device(d) = -dout * sin_x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341
// cos''(x) = -cos(x)
//
// Second-order backward of y = cos(x):
//   d2x   = -cos(x) * d1y * d2d1x   (written to dX)
//   d2d1y = -sin(x) * d2d1x         (written to ddOut)
template <typename T>
struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
  // X and ddX are always read. dOut may be null, in which case d2x is
  // filled with zeros. dX and ddOut are optional outputs: each is written
  // only when its own pointer is non-null.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto d2d1x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosDoubleGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "x", "CosDoubleGrad"));

    // calculate d2x first, so d2d1y can inplace d2d1x
    // The guard must test dX, the tensor written here. Previously the two
    // guards were swapped (ddOut guarded this write and dX guarded the
    // ddOut write), and dX was fetched before any null check.
    if (dX) {
      auto d2x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "d2x", "CosDoubleGrad"));
      if (dOut) {
        auto d1y = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad"));
        d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine<T>()) * d1y;
      } else {
        d2x.device(*d) = x * static_cast<T>(0);
      }
    }

    // calculate d2d1y
    if (ddOut) {
      auto d2d1y = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad"));
      d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine<T>());
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Third-order backward of y = cos(x).
//   d3x     = sin(x) * d1y * d2d1x * d3d2x - cos(x) * d2d1x * d3d2d1y
//   d3d1y   = -cos(x) * d2d1x * d3d2x
//   d3d2d1x = -cos(x) * d1y * d3d2x - sin(x) * d3d2d1y
// Any term whose optional input (dOut, ddX, d_DDOut) is null is dropped;
// when no term remains the output is filled with zeros. Each output is
// written only when its pointer is non-null.
template <typename T>
struct CosTripleGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  const DenseTensor* d_DDOut,
                  const DenseTensor* d_dx_New,
                  DenseTensor* d_d_Out,
                  DenseTensor* d_x_New,
                  DenseTensor* d_DDx) const {
    auto* d = dev.eigen_device();
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "x", "CosTripleGrad"));
    auto d3d2x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));

    // Presence of the optional inputs.
    const bool has_d1y = (dOut != nullptr);
    const bool has_d2d1x = (ddX != nullptr);
    const bool has_d3d2d1y = (d_DDOut != nullptr);

    // ---- d3x ----
    if (d_x_New) {
      auto d3x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
      if (!has_d2d1x || (!has_d1y && !has_d3d2d1y)) {
        // Both terms contain d2d1x; each also needs d1y or d3d2d1y.
        d3x.device(*d) = static_cast<T>(0) * x;
      } else {
        auto d2d1x = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
        if (has_d1y && has_d3d2d1y) {
          auto d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
          auto d3d2d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
          d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x -
                           x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
        } else if (has_d1y) {
          auto d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
          d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x;
        } else {
          auto d3d2d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
          d3x.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
        }
      }
    }

    // ---- d3d1y ----
    if (d_d_Out) {
      auto d3d1y = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad"));
      if (!has_d2d1x) {
        d3d1y.device(*d) = static_cast<T>(0) * x;
      } else {
        auto d2d1x = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
        d3d1y.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2x;
      }
    }

    // ---- d3d2d1x ----
    if (d_DDx) {
      auto d3d2d1x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad"));
      if (has_d3d2d1y) {
        auto d3d2d1y = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
        if (has_d1y) {
          auto d1y = EigenVector<T>::Flatten(
              GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
          d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x -
                               x.unaryExpr(Sine<T>()) * d3d2d1y;
        } else {
          d3d2d1x.device(*d) = -x.unaryExpr(Sine<T>()) * d3d2d1y;
        }
      } else if (has_d1y) {
        auto d1y = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
        d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x;
      } else {
        d3d2d1x.device(*d) = static_cast<T>(0) * x;
      }
    }
  }

  // NOTE(review): reports kDepOut although operator() only reads X --
  // confirm whether this value is intentional.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

449 450 451 452 453
// cosine(x) = cos(x)
// Forward of cos: out = cos(x), element-wise.
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // eval() is applied to the expression before it is assigned to out.
    auto cos_x = x.unaryExpr(Cosine<T>());
    out.device(d) = cos_x.eval();
  }
};

458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
// logit(x) = ln(x / (1 - x))
// x is first clamped to [eps, 1 - eps]. When eps is zero, inputs outside
// [0, 1] produce NaN instead of the clamped value. `p` is only used to
// build the constant NaN expression.
template <typename T>
struct LogitFunctor {
  template <typename Device, typename X, typename Out, typename P>
  void operator()(Device d, X x, Out out, P p, float eps) const {
    const T upper = static_cast<T>(1.0 - eps);
    const T lower = static_cast<T>(eps);
    auto clamped = (x.cwiseMin(upper)).cwiseMax(lower);
    auto log_odds = (clamped / (static_cast<T>(1) - clamped)).log();

    if (eps == 0.0f) {
      // No clamping margin requested: flag out-of-domain inputs with NaN.
      out.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
                          .select(p.constant(static_cast<T>(NAN)), log_odds);
      return;
    }
    out.device(d) = log_odds;
  }
};

// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
//             = ln(1 + exp(x)), otherwise

// Forward of mish: out = x * tanh(softplus(x)), where softplus(x) is taken
// to be x itself once x exceeds `threshold`.
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
  float threshold;

  // Exposes the "threshold" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T limit = static_cast<T>(threshold);
    const T one = static_cast<T>(1);
    auto softplus = (x > limit).select(x, (one + x.exp()).log());
    out.device(d) = x * softplus.tanh();
  }
};

// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)

// Backward of mish:
//   dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * (1 - exp(-sp)))
// with sp = softplus(x), taken to be x once x exceeds `threshold`.
template <typename T>
struct MishGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  // Exposes the "threshold" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T limit = static_cast<T>(threshold);
    const T one = static_cast<T>(1);
    auto softplus = (x > limit).select(x, (one + x.exp()).log());
    // d(softplus)/dx expressed through softplus itself.
    auto softplus_grad = one - (-softplus).exp();
    auto tanh_sp = softplus.tanh();
    dx.device(d) =
        dout * (tanh_sp + x * (one - tanh_sp * tanh_sp) * softplus_grad);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scaled tanh forward: out = scale_b * tanh(scale_a * x).
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;

  // Exposes the "scale_a" and "scale_b" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T inner = static_cast<T>(scale_a);
    const T outer = static_cast<T>(scale_b);
    out.device(d) = outer * (inner * x).tanh();
  }
};

// Scaled tanh backward:
//   dx = dout * scale_a * scale_b * (1 - tanh(scale_a * x)^2)
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
  float scale_a;
  float scale_b;

  // Exposes the "scale_a" and "scale_b" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto a = static_cast<T>(scale_a);
    auto b = static_cast<T>(scale_b);
    // Square a single tanh expression. Writing (a * x).tanh() twice in a
    // lazily evaluated product makes tanh run twice per element.
    auto temp = (a * x).tanh().square();
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579
// Scalar tangent functor, used element-wise through unaryExpr().
template <typename T>
struct Tangent {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = tan(val);
    return result;
  }
};

// float16: widen to float, take the tangent, narrow back.
template <>
struct Tangent<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(tan(widened));
  }
};

// Tangent'(x) = 1 / cos(x)^2
// Backward of tan: dx = dout / cos(x)^2. Reads the forward input x only.
template <typename T>
struct TanGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto cos_x = x.unaryExpr(Cosine<T>());
    dx.device(d) = dout / cos_x.square();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602
// square(x) = x^2
// Forward: element-wise square of x.
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto squared = x.square();
    out.device(d) = squared;
  }
};

// Backward of square: dx = dout * 2 * x. Reads the forward input x only.
template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T two = static_cast<T>(2);
    dx.device(d) = dout * two * x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// sqrt(x) = x^(1/2)
// Forward: element-wise square root of x.
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto root = x.sqrt();
    out.device(d) = root;
  }
};

// Backward of sqrt, written in terms of the forward output:
// dx = 0.5 * dout / out.
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    const T half = static_cast<T>(0.5);
    dx.device(d) = half * dout / out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// rsqrt(x) = x^(-1/2)
// Forward: element-wise reciprocal square root of x.
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto inv_root = x.rsqrt();
    out.device(d) = inv_root;
  }
};

// Backward of rsqrt, written in terms of the forward output:
// dx = -0.5 * dout * out^3.
template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    const T minus_half = static_cast<T>(-0.5);
    dx.device(d) = minus_half * dout * out * out * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// For numerical stability, use the following formula instead of
// softplus(x) = log(1 + exp(x)):
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold
// (beta = 1, threshold = 20 by default), otherwise x

// Softplus forward:
//   out = log(1 + exp(beta * x)) / beta   when beta * x <= threshold
//   out = x                               otherwise
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;

  // Exposes the "beta" and "threshold" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T beta_t = static_cast<T>(beta);
    auto x_beta = beta_t * x;
    auto soft_branch = (static_cast<T>(1) + x_beta.exp()).log() / beta_t;
    out.device(d) =
        (x_beta > static_cast<T>(threshold)).select(x, soft_branch);
  }
};

// For numerical stability, using the following formula instead of
// d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
// = 1, threshold = 20 by default), otherwise 1

// Softplus backward:
//   dx = dout / (1 + exp(-beta * x))   when beta * x <= threshold
//   dx = dout                          otherwise
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;

  // Exposes the "beta" and "threshold" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T beta_t = static_cast<T>(beta);
    auto x_beta = beta_t * x;
    auto soft_branch = dout / (static_cast<T>(1) + (-x_beta).exp());
    dx.device(d) =
        (x_beta > static_cast<T>(threshold)).select(dout, soft_branch);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

W
will-jl944 已提交
713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762
// Softplus second-order backward. With x_beta = beta * x:
//   dX    = ddx * dout * beta * exp(x_beta) / (exp(x_beta) + 1)^2
//           when x_beta <= threshold, 0 otherwise
//   ddOut = ddx / (1 + exp(-x_beta))
//           when x_beta <= threshold, ddx otherwise
// Each output is written only when its pointer is non-null.
template <typename T>
struct SoftplusDoubleGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;

  // Exposes the "beta" and "threshold" attribute storage.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "SoftplusDoubleGrad"));
    auto x_beta = static_cast<T>(beta) * x;
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SoftplusDoubleGrad"));

    const T one = static_cast<T>(1);
    const T limit = static_cast<T>(threshold);

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "SoftplusDoubleGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "SoftplusDoubleGrad"));
      auto exp_x_beta = x_beta.exp();
      auto below_threshold = ddx * dout * static_cast<T>(beta) * exp_x_beta /
                             (exp_x_beta + one).pow(static_cast<T>(2));
      dx.device(*d) = (x_beta > limit)
                          .select(x.constant(static_cast<T>(0)),
                                  below_threshold);
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SoftplusDoubleGrad"));
      auto below_threshold = ddx / (one + (-x_beta).exp());
      ddout.device(*d) = (x_beta > limit).select(ddx, below_threshold);
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

763 764 765 766 767
// Tangent(x) = tan(x)
// Forward of tan: out = tan(x), element-wise.
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // Note(GGBond8488): Since Eigen3.3, Behavior like {A = (B * A).cwiseAbs()}
    // will give wrong result, details see
    // http://eigen.tuxfamily.org/dox/group__TopicAliasing.html
    // Hence the explicit eval() before the assignment.
    auto tan_x = x.unaryExpr(Tangent<T>());
    out.device(d) = tan_x.eval();
  }
};

// Scalar hyperbolic-sine functor, used element-wise through unaryExpr().
template <typename T>
struct Sinh {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = sinh(val);
    return result;
  }
};

// float16: widen to float, apply sinhf, narrow back.
template <>
struct Sinh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(sinhf(widened));
  }
};

// Scalar hyperbolic-cosine functor, used element-wise through unaryExpr().
template <typename T>
struct Cosh {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = cosh(val);
    return result;
  }
};

// float16: widen to float, apply coshf, narrow back.
template <>
struct Cosh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(coshf(widened));
  }
};

// sinh(x) = sinh(x)
// Forward of sinh: out = sinh(x), element-wise.
template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // eval() is applied to the expression before it is assigned to out.
    auto sinh_x = x.unaryExpr(Sinh<T>());
    out.device(d) = sinh_x.eval();
  }
};

// cosh(x) = cosh(x)
// Forward of cosh: out = cosh(x), element-wise.
template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // eval() is applied to the expression before it is assigned to out.
    auto cosh_x = x.unaryExpr(Cosh<T>());
    out.device(d) = cosh_x.eval();
  }
};

// sinh'(x) = cosh(x)
// Backward of sinh: dx = dout * cosh(x). Reads the forward input x only.
template <typename T>
struct SinhGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto cosh_x = x.unaryExpr(Cosh<T>());
    dx.device(d) = dout * cosh_x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// cosh'(x) = sinh(x)
// Backward of cosh: dx = dout * sinh(x). Reads the forward input x only.
template <typename T>
struct CoshGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto sinh_x = x.unaryExpr(Sinh<T>());
    dx.device(d) = dout * sinh_x;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Scalar arc-cosine functor, used element-wise through unaryExpr().
template <typename T>
struct Acos {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = acos(val);
    return result;
  }
};

// float16: widen to float, take the arc-cosine, narrow back.
template <>
struct Acos<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(acos(widened));
  }
};

// Acos(x) = acos(x)
// Forward of acos: out = acos(x), element-wise.
template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // eval() is applied to the expression before it is assigned to out.
    auto acos_x = x.unaryExpr(Acos<T>());
    out.device(d) = acos_x.eval();
  }
};

// Backward functor for acos.
// acos'(x) = -1/sqrt(1-x^2), hence dx = -dout / sqrt(1 - x^2).
template <typename T>
struct AcosGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) =
        -dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scalar functor returning the arc sine of one element.
template <typename T>
struct Asin {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = asin(val);
    return result;
  }
};

// float16 specialization: the value is widened to float, asin is taken in
// float, and the result is converted back to float16.
template <>
struct Asin<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float wide = static_cast<float>(val);
    return dtype::float16(asin(wide));
  }
};

// Forward functor for arc sine: out = asin(x), applied element-wise
// through the Asin<T> scalar functor.
template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
  // d:   device the expression is assigned on.
  // x:   input expression.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Asin<T>()).eval();
  }
};

// Backward functor for asin.
// asin'(x) = 1/sqrt(1-x^2), hence dx = dout / sqrt(1 - x^2).
template <typename T>
struct AsinGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) =
        dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scalar functor returning the arc tangent of one element.
template <typename T>
struct Atan {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = atan(val);
    return result;
  }
};

// float16 specialization: the value is widened to float, atan is taken in
// float, and the result is converted back to float16.
template <>
struct Atan<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float wide = static_cast<float>(val);
    return dtype::float16(atan(wide));
  }
};

// Forward functor for arc tangent: out = atan(x), applied element-wise
// through the Atan<T> scalar functor.
template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
  // d:   device the expression is assigned on.
  // x:   input expression.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Atan<T>()).eval();
  }
};

// Backward functor for atan.
// atan'(x) = 1 / (1 + x^2), hence dx = dout / (1 + x^2).
template <typename T>
struct AtanGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Backward functor for logit.
// logit'(x) = 1 / (x * (1 - x)). Elements with x < eps or x > 1 - eps
// receive a zero gradient instead.
template <typename T>
struct LogitGradFunctor {
  // x:    forward input.
  // dout: incoming gradient.
  // dx:   destination for the computed gradient.
  // p:    expression whose constant() supplies the zero fill.
  // eps:  clipping bound applied to both ends of the input range.
  template <typename Device, typename X, typename dOut, typename dX, typename P>
  void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
    const T lower = static_cast<T>(eps);
    const T upper = static_cast<T>(1.0 - eps);
    const T one = static_cast<T>(1);
    auto out_of_range = x < lower || x > upper;
    auto in_range_grad = dout * (one / ((one - x) * x));
    dx.device(d) =
        out_of_range.select(p.constant(static_cast<T>(0)), in_range_grad);
  }
};

// Scalar functor returning the inverse hyperbolic cosine of one element.
template <typename T>
struct Acosh {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = acosh(val);
    return result;
  }
};

// float16 specialization: the value is widened to float, acosh is taken in
// float, and the result is converted back to float16.
template <>
struct Acosh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float wide = static_cast<float>(val);
    return dtype::float16(acosh(wide));
  }
};

// Forward functor for inverse hyperbolic cosine: out = acosh(x), applied
// element-wise through the Acosh<T> scalar functor.
template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
  // d:   device the expression is assigned on.
  // x:   input expression.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Acosh<T>()).eval();
  }
};

// Backward functor for acosh.
// acosh'(x) = 1/sqrt(x^2 - 1), hence dx = dout / sqrt(x^2 - 1).
template <typename T>
struct AcoshGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) =
        dout * static_cast<T>(1) / (x * x - static_cast<T>(1)).sqrt();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scalar functor returning the inverse hyperbolic sine of one element.
template <typename T>
struct Asinh {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = asinh(val);
    return result;
  }
};

// float16 specialization: the value is widened to float, asinh is taken in
// float, and the result is converted back to float16.
template <>
struct Asinh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float wide = static_cast<float>(val);
    return dtype::float16(asinh(wide));
  }
};

// Forward functor for inverse hyperbolic sine: out = asinh(x), applied
// element-wise through the Asinh<T> scalar functor.
template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
  // d:   device the expression is assigned on.
  // x:   input expression.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Asinh<T>()).eval();
  }
};

// Backward functor for asinh.
// asinh'(x) = 1/sqrt(x^2 + 1), hence dx = dout / sqrt(x^2 + 1).
template <typename T>
struct AsinhGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) =
        dout * static_cast<T>(1) / (x.square() + static_cast<T>(1)).sqrt();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Scalar functor returning the inverse hyperbolic tangent of one element.
template <typename T>
struct Atanh {
  HOSTDEVICE T operator()(const T& val) const {
    const T result = atanh(val);
    return result;
  }
};

// float16 specialization: the value is widened to float, atanh is taken in
// float, and the result is converted back to float16.
template <>
struct Atanh<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float wide = static_cast<float>(val);
    return dtype::float16(atanh(wide));
  }
};

// Forward functor for inverse hyperbolic tangent: out = atanh(x), applied
// element-wise through the Atanh<T> scalar functor.
template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
  // d:   device the expression is assigned on.
  // x:   input expression.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Atanh<T>()).eval();
  }
};

// Backward functor for atanh.
// atanh'(x) = 1/(1 - x^2), hence dx = dout / (1 - x^2).
template <typename T>
struct AtanhGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) - x.square());
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// Forward functor for the exponential: out = e^x.
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
  // Computation type: float when T is an integral type, otherwise T itself.
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // d:   device the expression is assigned on.
  // x:   input expression; cast to U before exp() is applied.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.template cast<U>().exp();
  }
};

// Backward functor for exp: dx = dout * out, where out is the forward
// output.
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
  // out is the forward output, dout the incoming gradient, dx the result.
  // The forward input `x` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
  }

  // The backward computation reads only the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Forward functor for expm1: out = e^x - 1.
template <typename T>
struct Expm1Functor : public BaseActivationFunctor<T> {
  // Computation type: float when T is an integral type, otherwise T itself.
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // d:   device the expression is assigned on.
  // x:   input expression; cast to U before expm1() is applied.
  // out: destination expression.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.template cast<U>().expm1();
  }
};

// Backward functor for expm1: dx = dout * out + dout, where out is the
// forward output (out + 1 equals e^x).
template <typename T>
struct Expm1GradFunctor : public BaseActivationFunctor<T> {
  // out is the forward output, dout the incoming gradient, dx the result.
  // The forward input `x` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out + dout;
  }

  // The backward computation reads only the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// ReLU forward functor built on a per-element lambda:
// relu(x) = x when x > 0, otherwise 0.
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto clamp_at_zero = [] HOSTDEVICE(T v) {
      return v > static_cast<T>(0) ? v : static_cast<T>(0);
    };
    out.device(d) = x.unaryExpr(clamp_at_zero);
  }
};

// ReLU forward functor expressed as an element-wise maximum with zero.
template <typename T>
struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T zero = static_cast<T>(0);
    out.device(d) = x.cwiseMax(zero);
  }
};

// Backward functor for ReLU: dx = dout where the forward output is
// strictly positive, and 0 elsewhere.
template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
  // out is the forward output, dout the incoming gradient, dx the result.
  // The forward input `x` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    // (out > 0) cast to T acts as a 0/1 mask on the incoming gradient.
    dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
  }

  // The backward computation reads only the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Second-order backward functor for ReLU:
// ddout = ddx * (out > 0), written only when ddOut is non-null.
template <typename T>
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
  // dev:   device context providing eigen_device().
  // Out:   forward output tensor (checked via GET_DATA_SAFELY).
  // ddX:   incoming second-order gradient (checked via GET_DATA_SAFELY).
  // ddOut: destination tensor; nothing is computed when it is null.
  // X, dOut and dX are accepted but not read.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X UNUSED,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  DenseTensor* dOut UNUSED,
                  DenseTensor* dX UNUSED) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
      ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
    }
  }
  // The computation reads only the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Forward functor for the hyperbolic tangent:
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto activated = x.tanh();
    out.device(d) = activated;
  }
};

// Backward functor for tanh: dx = dout * (1 - out^2), where out is the
// forward output.
template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
  // out is the forward output, dout the incoming gradient, dx the result.
  // The forward input `x` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
  }

  // The backward computation reads only the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Second-order backward functor for tanh.
//   dout_new = -1 * dout * 2 * out * ddx   (written only if dOutNew != null)
//   ddout    = (1 - out^2) * ddx           (written only if ddOut != null)
template <typename T>
struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  DenseTensor* dOutNew,
                  DenseTensor* ddOut) const {
    auto* eigen_dev = dev.eigen_device();
    // ddX and Out are required regardless of which outputs are requested.
    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhGradGrad"));
    auto out_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "TanhGradGrad"));

    const T one = static_cast<T>(1);
    const T two = static_cast<T>(2);
    const T minus_one = static_cast<T>(-1);

    if (dOutNew) {
      // dOut is only dereferenced when the dOutNew output is requested.
      auto dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhGradGrad"));
      auto dout_new_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "TanhGradGrad"));
      dout_new_vec.device(*eigen_dev) =
          minus_one * dout_vec * two * out_vec * ddx_vec;
    }
    if (ddOut) {
      auto ddout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "TanhGradGrad"));
      ddout_vec.device(*eigen_dev) = (one - out_vec * out_vec) * ddx_vec;
    }
  }
  // The computation reads the forward output Out, not the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
/*
    Out
    DOut                            D_Dout
    DDx     -> TanhTripleGrad ->    D_DDx
    D_DDout                         d_OutNew
    D_Dout_new

    D_Dout = (-2) * Out * DDx * D_Dout_new
    D_DDx = (1-Out^2)*D_DDout + (-2) * Out * DOut * D_Dout_new
    D_OutNew = (-2) * Out * DDx * D_DDout + (-2) * DOut * DDx * D_Dout_new

    Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    D_OutNew, D_DOut, D_DDx               // output
*/
// Third-order backward functor for tanh.
//
// Inputs:  Out, ddX, dOut (always required), d_DDOut, d_dOut_New (optional).
// Outputs: d_Out_New, d_d_Out, d_DDx (each written only when non-null).
//
//   d_Out_New = -2 * out * ddx * d_ddOut - 2 * dout * ddx * d_dOutNew
//   d_d_Out   = -2 * out * ddx * d_dOutNew
//   d_DDx     = (1 - out^2) * d_ddOut - 2 * out * dout * d_dOutNew
//
// A term whose optional input is null is dropped; when every term of an
// output is dropped, that output is filled with `0 * out` (or `0 * ddx` for
// d_DDx).
template <typename T>
struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  const DenseTensor* d_DDOut,
                  const DenseTensor* d_dOut_New,
                  DenseTensor* d_d_Out,
                  DenseTensor* d_Out_New,
                  DenseTensor* d_DDx) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhTripleGrad"));
    auto out = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad"));
    auto dout = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad"));

    // Gradient with respect to Out.
    if (d_Out_New) {
      auto d_OutNew = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad"));

      if (d_DDOut && d_dOut_New) {
        auto d_ddOut = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));

        d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut) -
                              (static_cast<T>(2) * dout * ddx * d_dOutNew);

      } else if (d_DDOut && !d_dOut_New) {
        auto d_ddOut = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));

        d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut);

      } else if (!d_DDOut && d_dOut_New) {
        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));

        d_OutNew.device(*d) = -(static_cast<T>(2) * dout * ddx * d_dOutNew);
      } else {
        // Neither optional input is present: fill via 0 * out.
        d_OutNew.device(*d) = static_cast<T>(0) * out;
      }
    }
    // Gradient with respect to DOut.
    if (d_d_Out) {
      auto d_dOut = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad"));

      if (d_dOut_New) {
        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
        d_dOut.device(*d) = static_cast<T>(-2) * out * ddx * d_dOutNew;
      } else {
        d_dOut.device(*d) = static_cast<T>(0) * out;
      }
    }
    // Gradient with respect to DDX.
    if (d_DDx) {
      auto d_ddx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad"));

      if (d_DDOut && d_dOut_New) {
        auto d_ddOut = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
        d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut -
                           static_cast<T>(2) * out * dout * d_dOutNew;

      } else if (d_DDOut && !d_dOut_New) {
        auto d_ddOut = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
        d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut;
      } else if (!d_DDOut && d_dOut_New) {
        auto d_dOutNew = EigenVector<T>::Flatten(GET_DATA_SAFELY(
            d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));
        d_ddx.device(*d) = -static_cast<T>(2) * out * dout * d_dOutNew;
      } else {
        // Neither optional input is present: fill via 0 * ddx.
        d_ddx.device(*d) = static_cast<T>(0) * ddx;
      }
    }
  }
  // The computation reads the forward output Out, not the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
1359
struct HardTanhFunctor : public BaseActivationFunctor<T> {
1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376
  float t_min;
  float t_max;

  // NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
  // not polymorphism for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
  }
};

template <typename T>
1377
struct HardTanhGradFunctor : public BaseActivationFunctor<T> {
1378 1379 1380 1381 1382 1383 1384 1385 1386 1387
  float t_min;
  float t_max;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
1388
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
1389 1390 1391
    dx.device(d) =
        dout * ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
                   .template cast<T>();
1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for leaky ReLU.
// With alpha < 1 the result is max(x, alpha * x); otherwise it is
// min(x, alpha * x).
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
  float alpha;  // slope multiplier
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T slope = static_cast<T>(alpha);
    auto scaled = slope * x;
    if (alpha < 1.f) {
      out.device(d) = x.cwiseMax(scaled);
    } else {
      out.device(d) = x.cwiseMin(scaled);
    }
  }
};

// Backward functor for leaky ReLU:
// dx = dout * alpha where x < 0, and dx = dout where x >= 0.
template <typename T>
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  float alpha;  // slope applied where x < 0
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    // temp1: alpha on the negative side; temp2: 1 on the non-negative side.
    auto temp1 =
        static_cast<T>(alpha) * (x < static_cast<T>(0)).template cast<T>();
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Second-order backward functor for leaky ReLU:
// ddout = ddx where x > 0, and ddout = ddx * alpha where x <= 0.
// Nothing is read or written when ddOut is null.
template <typename T>
struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;  // slope applied where x <= 0
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  // dev:   device context providing eigen_device().
  // X:     forward input tensor.
  // ddX:   incoming second-order gradient.
  // ddOut: destination tensor; may be null.
  // Out, dOut and dX are accepted but not read.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* Out UNUSED,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  DenseTensor* dOut UNUSED,
                  DenseTensor* dX UNUSED) const {
    if (ddOut) {
      auto* d = dev.eigen_device();
      auto ddx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
      auto x = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad"));
      // The label names the tensor actually being checked (DDOut).
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LeakyReluGradGrad"));
      // NOTE(review): x == 0 takes the alpha branch here, whereas
      // LeakyReluGradFunctor treats x == 0 as the slope-1 side - confirm
      // which convention is intended.
      ddout.device(*d) = ddx * ((x > static_cast<T>(0)).template cast<T>() +
                                static_cast<T>(alpha) *
                                    (x <= static_cast<T>(0)).template cast<T>())
                                   .template cast<T>();
    }
  }
  // The computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for thresholded ReLU:
// out = x where x > threshold, and 0 elsewhere.
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
  float threshold;  // exclusive lower bound for passing x through
  // Exposes a pointer to threshold keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto cutoff = static_cast<T>(threshold);
    // 0/1 mask selecting the elements above the cutoff.
    auto keep = (x > cutoff).template cast<T>();
    out.device(d) = keep * x;
  }
};

// Backward functor for thresholded ReLU:
// dx = dout where x > threshold, and 0 elsewhere.
template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;  // exclusive lower bound for passing the gradient
  // Exposes a pointer to threshold keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dout * (x > th).template cast<T>();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for relu6: out = min(max(0, x), threshold).
// The upper bound comes from the `threshold` attribute.
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;  // upper clamp bound

  // Exposes a pointer to threshold keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T floor_value = static_cast<T>(0);
    const T ceil_value = static_cast<T>(threshold);
    out.device(d) = x.cwiseMax(floor_value).cwiseMin(ceil_value);
  }
};

// Backward functor for relu6: dx = dout where 0 < out < 6 (both bounds
// exclusive), and 0 elsewhere. `out` is the forward output.
template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  // This functor exposes no attributes. An empty list is returned rather
  // than `{{}}`, which would build a single entry holding a null name and a
  // null value pointer that a consumer iterating the list could dereference.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() { return {}; }
  // out is the forward output, dout the incoming gradient, dx the result.
  // The forward input `x` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    // NOTE(review): the upper bound is fixed at 6 here while Relu6Functor
    // takes it from its `threshold` attribute - confirm the two cannot
    // diverge.
    float threshold = 6;
    dx.device(d) =
        dout * ((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
                   .template cast<T>();
  }

  // The backward computation reads only the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Forward functor for tanhshrink: out = x - tanh(x),
// with tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)).
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto squashed = x.tanh();
    out.device(d) = x - squashed;
  }
};

// Backward functor for tanhshrink: dx = dout * tanh(x)^2.
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for hardshrink:
// out = x where x < -threshold or x > threshold, and 0 elsewhere.
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;  // half-width of the zeroed band around 0

  // Exposes a pointer to threshold keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto below_band = x < static_cast<T>(threshold * -1.f);
    auto above_band = x > static_cast<T>(threshold);
    // 0/1 mask selecting elements outside [-threshold, threshold].
    auto keep = (below_band || above_band).template cast<T>();
    out.device(d) = x * keep;
  }
};

// Backward functor for hardshrink:
// dx = dout where x < -threshold or x > threshold, and 0 elsewhere.
template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;  // half-width of the zeroed band around 0

  // Exposes a pointer to threshold keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto temp1 = x < static_cast<T>(threshold * -1.f);
    auto temp2 = x > static_cast<T>(threshold);
    dx.device(d) = dout * (temp1 || temp2).template cast<T>();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for softshrink:
//   out = x - lambda  where x >  lambda
//   out = x + lambda  where x < -lambda
//   out = 0           otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;  // shrinkage amount
  // Exposes a pointer to lambda keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto shift = static_cast<T>(lambda);
    // 0/1 masks for the two regions outside [-lambda, lambda].
    auto above = (x > shift).template cast<T>();
    auto below = (x < -shift).template cast<T>();
    out.device(d) = above * (x - shift) + below * (x + shift);
  }
};

// Backward functor for softshrink:
// dx = dout where x > lambda or x < -lambda, and 0 elsewhere.
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;  // shrinkage amount
  // Exposes a pointer to lambda keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    auto lambdaT = static_cast<T>(lambda);
    auto temp1 = (x > lambdaT).template cast<T>();
    auto temp2 = (x < -lambdaT).template cast<T>();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for ELU:
// out = alpha * (exp(x) - 1) where x < 0, and out = x elsewhere.
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;  // scale of the negative branch
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto is_negative = x < static_cast<T>(0);
    auto negative_branch =
        static_cast<T>(alpha) * (x.exp() - static_cast<T>(1));
    out.device(d) = is_negative.select(negative_branch, x);
  }
};

// Backward functor for ELU, expressed through the forward output:
// dx = dout where out > 0, and dx = dout * (out + alpha) elsewhere.
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;  // scale of the negative branch
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  // out is the forward output, dout the incoming gradient, dx the result.
  // The forward input `x` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    // case 1: alpha >= 0
    // dx = dout, if out > 0
    // dx = dout * (out + alpha), if out <= 0
    dx.device(d) = (out > static_cast<T>(0))
                       .select(dout, dout * (out + static_cast<T>(alpha)));
  }

  // NOTE(review): the computation above reads `out` and ignores `x`, yet
  // kDepX is reported here - confirm this is intentional.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Backward functor for ELU, expressed through the forward input:
// dx = dout where x > 0, and dx = dout * alpha * exp(x) elsewhere.
template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  float alpha;  // scale of the negative branch
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  // x is the forward input, dout the incoming gradient, dx the result.
  // The forward output `out` is accepted but not read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    // case 2: alpha < 0
    // dx = dout, if x > 0
    // dx = dout * alpha * exp(x), if x <= 0
    dx.device(d) = (x > static_cast<T>(0))
                       .select(dout, dout * static_cast<T>(alpha) * x.exp());
  }

  // The backward computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Second-order backward functor for ELU.
//   dx    = ddx * dout * alpha * exp(x) * (x <= 0)      (if dX != null)
//   ddout = ddx * ((x > 0) + alpha * exp(x) * (x <= 0)) (if ddOut != null)
template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;  // scale of the negative branch
  // Exposes a pointer to alpha keyed by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  // dev:   device context providing eigen_device().
  // X:     forward input tensor (always required).
  // ddX:   incoming second-order gradient (always required).
  // ddOut: destination for ddout; may be null.
  // dOut:  first-order incoming gradient; read only when dX is non-null.
  // dX:    destination for dx; may be null.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  const DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
      // dOut is a read-only tensor here, so it is labelled "Input".
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "ELUGradGrad"));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
      ddout.device(*d) = ddx * ((x > static_cast<T>(0)).template cast<T>() +
                                static_cast<T>(alpha) * x.exp() *
                                    (x <= static_cast<T>(0)).template cast<T>())
                                   .template cast<T>();
    }
  }
  // The computation reads only the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
  // out = x * sigmoid(x), with sigmoid(x) = 1 / (1 + exp(-x)).
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T one = static_cast<T>(1);
    auto sigmoid = one / (one + (-x).exp());
    out.device(d) = x * sigmoid;
  }
};

// silu'(x) = (1 / (1 + e^{-x})) * (1 + silu(x) * e^{-x})
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * s * (1 + x * (1 - s)), where s = sigmoid(x) = 1 / (1 + e^{-x}).
  //
  // This is algebraically the same as s * (1 + x * e^{-x} / (1 + e^{-x})),
  // but it never forms x * e^{-x}: for very negative x, e^{-x} overflows to
  // +inf and the quotient (-inf / +inf) would be NaN, whereas here s becomes
  // 0 and the gradient correctly goes to 0.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    auto sig = one / (one + (-x).exp());  // sigmoid(x)
    dx.device(d) = dout * (sig * (one + x * (one - sig)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  // out = x / (1 + |x|)
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T one = static_cast<T>(1);
    auto denominator = one + x.abs();
    out.device(d) = x / denominator;
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function

template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / (1 + |x|)^2; only the forward input is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    auto squared_denominator = (one + x.abs()).square();
    dx.device(d) = dout * (one / squared_denominator);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

Y
YuanRisheng 已提交
1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
  // Element-wise logistic function of x.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T one = static_cast<T>(1);
    auto denominator = one + (-x).exp();
    out.device(d) = one / denominator;
  }
};

template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out * (1 - out); only the forward output is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    dx.device(d) = dout * out * (one - out);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

/*
    Out
    DOut -> SigmoidGradGrad -> DOutNew
    DDX                        DDOut

    DDOut = (1-Out)*Out*DDX
    DOutNew = (1-2*Out)*DOut*DDX
*/
template <typename T>
struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
  // Second-order sigmoid gradient. Both outputs are optional and are only
  // written when their pointer is non-null:
  //   dOutNew = (1 - 2 * Out) * dOut * ddX
  //   ddOut   = (1 - Out) * Out * ddX
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  DenseTensor* dOutNew,
                  DenseTensor* ddOut) const {
    const T one = static_cast<T>(1);
    const T two = static_cast<T>(2);
    auto* place = dev.eigen_device();

    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidGradGrad"));
    auto out_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidGradGrad"));

    if (dOutNew) {
      auto dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidGradGrad"));
      auto dout_new_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "SigmoidGradGrad"));
      dout_new_vec.device(*place) =
          (one - two * out_vec) * dout_vec * ddx_vec;
    }

    if (ddOut) {
      auto ddout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SigmoidGradGrad"));
      ddout_vec.device(*place) = (one - out_vec) * out_vec * ddx_vec;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

/*
    Out
    DOut                            D_Dout
    DDx     -> SigmoidTripleGrad -> D_DDx
    D_DDout                         d_OutNew
    D_Dout_new

    D_Dout = (1-2*Out)*DDx*D_Dout_new
    D_DDx = (1-Out)*Out*D_DDout + (1-2*Out)*DOut*D_Dout_new
    D_OutNew = (DDx-2*Out*DDx)*D_DDout - 2*DOut*DDx*D_Dout_new

    Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    D_OutNew, D_DOut, D_DDx               // output
*/
template <typename T>
struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
  // Third-order sigmoid gradient. Out, ddX, dOut and d_dOut_New are always
  // dereferenced; d_DDOut is an optional input and the three outputs are
  // optional as well (each is written only when its pointer is non-null).
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* ddX,
                  const DenseTensor* dOut,
                  const DenseTensor* d_DDOut,
                  const DenseTensor* d_dOut_New,
                  DenseTensor* d_d_Out,
                  DenseTensor* d_Out_New,
                  DenseTensor* d_DDx) const {
    const T one = static_cast<T>(1);
    const T two = static_cast<T>(2);
    auto* place = dev.eigen_device();

    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidTripleGrad"));
    auto out_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidTripleGrad"));
    auto dout_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidTripleGrad"));
    auto d_dout_new_vec = EigenVector<T>::Flatten(GET_DATA_SAFELY(
        d_dOut_New, "Input", "D_DOut_New", "SigmoidTripleGrad"));

    if (d_Out_New) {
      // D_OutNew = (DDx - 2*Out*DDx) * D_DDOut - 2 * DOut * DDx * D_DOut_New.
      // The second term is written first; the D_DDOut term is accumulated on
      // top of it only when that input is present.
      auto d_out_new_vec = EigenVector<T>::Flatten(GET_DATA_SAFELY(
          d_Out_New, "Output", "D_OutNew", "SigmoidTripleGrad"));
      d_out_new_vec.device(*place) =
          -two * dout_vec * ddx_vec * d_dout_new_vec;
      if (d_DDOut) {
        auto d_ddout_vec = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad"));
        d_out_new_vec.device(*place) =
            (ddx_vec - two * out_vec * ddx_vec) * d_ddout_vec + d_out_new_vec;
      }
    }

    if (d_d_Out) {
      // D_DOut = (1 - 2*Out) * DDx * D_DOut_New
      auto d_dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "SigmoidTripleGrad"));
      d_dout_vec.device(*place) =
          (one - two * out_vec) * ddx_vec * d_dout_new_vec;
    }

    if (d_DDx) {
      // D_DDx = (1 - Out) * Out * D_DDOut + (1 - 2*Out) * DOut * D_DOut_New.
      // As above, the D_DDOut term is only added when that input is present.
      auto d_ddx_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "SigmoidTripleGrad"));
      d_ddx_vec.device(*place) =
          (one - two * out_vec) * dout_vec * d_dout_new_vec;
      if (d_DDOut) {
        auto d_ddout_vec = EigenVector<T>::Flatten(
            GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad"));
        d_ddx_vec.device(*place) =
            d_ddx_vec + (one - out_vec) * out_vec * d_ddout_vec;
      }
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
//           max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = -log( exp(max(-x, 0))) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  // out = -(m + log(exp(-m) + exp(-x - m))), with m = max(-x, 0).
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T zero = static_cast<T>(0);
    auto shift = (-x).cwiseMax(zero);  // m = max(-x, 0)
    auto log_sum = ((-shift).exp() + (-x - shift).exp()).log();
    out.device(d) = -shift - (log_sum);
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * exp(-x - m) / (exp(-m) + exp(-x - m)), with m = max(-x, 0).
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T zero = static_cast<T>(0);
    auto shift = (-x).cwiseMax(zero);  // m = max(-x, 0)
    auto numerator = (-x - shift).exp();
    auto denominator = (-shift).exp() + (-x - shift).exp();
    dx.device(d) = dout * (numerator / denominator);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
  // Coefficients of the linear segment, bound through GetAttrs().
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // out = clamp(slope * x + offset, 0, 1)
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T lower = static_cast<T>(0);
    const T upper = static_cast<T>(1);
    auto linear = x * static_cast<T>(slope) + static_cast<T>(offset);
    out.device(d) = linear.cwiseMax(lower).cwiseMin(upper);
  }
};

template <typename T>
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  // Coefficients of the linear segment, bound through GetAttrs().
  float slope;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // dx = dout * slope where 0 < out < 1, and 0 elsewhere.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    const T zero = static_cast<T>(0);
    const T one = static_cast<T>(1);
    // 1 where the forward output lies strictly between 0 and 1, else 0.
    auto in_linear_region = ((out > zero) * (out < one)).template cast<T>();
    dx.device(d) = dout * in_linear_region * static_cast<T>(slope);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027
// Scalar natural-logarithm functor.
template <typename T>
struct Log {
  HOSTDEVICE T operator()(const T& val) const {
    return std::log(val);
  }
};

// float16: the logarithm is evaluated in float and converted back.
template <>
struct Log<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(std::log(widened));
  }
};

// bfloat16: the logarithm is evaluated in float and converted back.
template <>
struct Log<dtype::bfloat16> {
  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::bfloat16(std::log(widened));
  }
};

2028 2029 2030
// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
  // Computation type: float for integral T, otherwise T itself.
  using U =
      typename std::conditional<std::is_integral<T>::value, float, T>::type;

  // Casts x to U and applies Log<U> element-wise.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto promoted = x.template cast<U>();
    out.device(d) = promoted.unaryExpr(Log<U>()).eval();
  }
};

template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * (1 / x); only the forward input is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    auto reciprocal = one / x;
    dx.device(d) = dout * (reciprocal);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071
// Scalar base-2 logarithm functor.
template <typename T>
struct Log2 {
  HOSTDEVICE T operator()(const T& val) const {
    return std::log2(val);
  }
};

// float16: the logarithm is evaluated in float and converted back.
template <>
struct Log2<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(std::log2(widened));
  }
};

// bfloat16: the logarithm is evaluated in float and converted back.
template <>
struct Log2<dtype::bfloat16> {
  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::bfloat16(std::log2(widened));
  }
};

2072 2073 2074
// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
  // Computation type: float for integral T, otherwise T itself.
  using U =
      typename std::conditional<std::is_integral<T>::value, float, T>::type;

  // Casts x to U and applies Log2<U> element-wise.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto promoted = x.template cast<U>();
    out.device(d) = promoted.unaryExpr(Log2<U>()).eval();
  }
};

// the gradient of log2(x) is 1/(x*ln(2))
template <typename T>
struct Log2GradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / (x * ln(2)); only the forward input is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    const T ln_two = static_cast<T>(log(2));
    auto scaled_input = x * ln_two;
    dx.device(d) = dout * one / (scaled_input);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116
// Scalar base-10 logarithm functor.
template <typename T>
struct Log10 {
  HOSTDEVICE T operator()(const T& val) const {
    return std::log10(val);
  }
};

// float16: the logarithm is evaluated in float and converted back.
template <>
struct Log10<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(std::log10(widened));
  }
};

// bfloat16: the logarithm is evaluated in float and converted back.
template <>
struct Log10<dtype::bfloat16> {
  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::bfloat16(std::log10(widened));
  }
};

2117 2118 2119
// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
  // Computation type: float for integral T, otherwise T itself.
  using U =
      typename std::conditional<std::is_integral<T>::value, float, T>::type;

  // Casts x to U and applies Log10<U> element-wise.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto promoted = x.template cast<U>();
    out.device(d) = promoted.unaryExpr(Log10<U>()).eval();
  }
};

// the gradient of log10(x) is 1/(x*ln(10))
template <typename T>
struct Log10GradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / (x * ln(10)); only the forward input is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    const T ln_ten = static_cast<T>(log(10));
    auto scaled_input = x * ln_ten;
    dx.device(d) = dout * one / (scaled_input);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161
// Scalar log(1 + x) functor.
template <typename T>
struct Log1p {
  HOSTDEVICE T operator()(const T& val) const {
    return std::log1p(val);
  }
};

// float16: log1p is evaluated in float and converted back.
template <>
struct Log1p<dtype::float16> {
  HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::float16(std::log1p(widened));
  }
};

// bfloat16: log1p is evaluated in float and converted back.
template <>
struct Log1p<dtype::bfloat16> {
  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
    const float widened = static_cast<float>(val);
    return dtype::bfloat16(std::log1p(widened));
  }
};

2162 2163 2164
// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
  // Computation type: float for integral T, otherwise T itself.
  using U =
      typename std::conditional<std::is_integral<T>::value, float, T>::type;

  // Casts x to U and applies Log1p<U> element-wise.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto promoted = x.template cast<U>();
    out.device(d) = promoted.unaryExpr(Log1p<U>()).eval();
  }
};

template <typename T>
struct Log1pGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * (1 / (x + 1)); only the forward input is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    auto shifted = x + one;
    dx.device(d) = dout * (one / (shifted));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
  // Second-order gradient of log. Both outputs are optional:
  //   dX    = -dOut * ddX / (X * X)
  //   ddOut = ddX / X
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  const DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* place = dev.eigen_device();
    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
    auto x_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));

    // dX must be produced before ddOut so that ddOut may be computed in place
    // over ddX.
    if (dX) {
      auto dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
      auto dx_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
      const T minus_one = static_cast<T>(-1);
      dx_vec.device(*place) = dout_vec * minus_one * ddx_vec / (x_vec * x_vec);
    }

    if (ddOut) {
      auto ddout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
      const T one = static_cast<T>(1);
      ddout_vec.device(*place) = ddx_vec * one / x_vec;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

Y
YuanRisheng 已提交
2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253
// HardSwish = min(max(0, x+3), 6) * x / 6
template <typename T>
struct HardSwishFunctor : public BaseActivationFunctor<T> {
  // Shape parameters, bound through GetAttrs().
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // out = min(max(x + offset, 0), threshold) * x / scale
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T zero = static_cast<T>(0);
    const T upper = static_cast<T>(threshold);
    auto gate = (x + static_cast<T>(offset)).cwiseMax(zero).cwiseMin(upper);
    out.device(d) = gate * x / static_cast<T>(scale);
  }
};

template <typename T>
struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
  // Shape parameters, bound through GetAttrs().
  float threshold;
  float scale;
  float offset;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // With s = x + offset:
  //   s >= threshold     : dx = dout
  //   0 < s < threshold  : dx = dout * (2 * x + offset) / scale
  //   s <= 0             : dx = 0
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    const T zero = static_cast<T>(0);
    const T one = static_cast<T>(1);
    const T two = static_cast<T>(2);
    // 0/1 indicators of s < threshold and s > 0.
    auto below_threshold =
        ((x + static_cast<T>(offset)) < static_cast<T>(threshold))
            .template cast<T>();
    auto above_zero = ((x + static_cast<T>(offset)) > zero).template cast<T>();
    dx.device(d) =
        dout * (above_zero * (two * x + static_cast<T>(offset)) /
                    static_cast<T>(scale) * below_threshold +
                one * (one - below_threshold));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  // Sharpness parameter, bound through GetAttrs().
  float beta;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // out = x / (1 + exp(-beta * x))
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T one = static_cast<T>(1);
    const T neg_beta = static_cast<T>(-beta);
    auto denominator = one + (neg_beta * x).exp();
    out.device(d) = x / (denominator);
  }
};

template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  // This functor exposes no attributes (beta is fixed to 1 in operator()),
  // so the attribute list is empty. The previous `{{}}` initializer built a
  // one-element list holding a {nullptr, nullptr} pair, which any code that
  // walks the list and writes through `second` would dereference.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() { return {}; }

  // Swish backward with beta = 1. The forward output is recomputed from x;
  // the `fake_out` argument is ignored.
  //   s   = 1 / (1 + exp(-beta * x))
  //   out = x * s
  //   dx  = dout * (beta * out + s * (1 - beta * out))
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out fake_out UNUSED, dOut dout, dX dx) const {
    const float beta = 1.0f;
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto out = x * temp1;
    auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
    dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  // Exponent, bound through GetAttrs().
  float factor;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }

  // out = x ^ factor
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T exponent = static_cast<T>(factor);
    out.device(d) = x.pow(exponent);
  }
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
  // Exponent of the forward op, bound through GetAttrs().
  float factor;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"factor", &factor}};
  }

  // dx = dout * factor * x ^ (factor - 1); only the forward input is read.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    // The reduced exponent is formed in T, as in the forward cast.
    const T reduced_exponent = static_cast<T>(factor) - static_cast<T>(1);
    dx.device(d) = dout * static_cast<T>(factor) * x.pow(reduced_exponent);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
  // Writes the element-wise floor of x into out.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto floored = x.floor();
    out.device(d) = floored;
  }
};

// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
  // Writes the element-wise round() of x into out.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto rounded = x.round();
    out.device(d) = rounded;
  }
};

// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
  // Writes the element-wise ceiling of x into out.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto ceiled = x.ceil();
    out.device(d) = ceiled;
  }
};

2361 2362 2363 2364 2365 2366 2367 2368
template <typename T>
struct NegativeFunctor : public BaseActivationFunctor<T> {
  // Writes the element-wise negation of x into out.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto negated = -x;
    out.device(d) = negated;
  }
};

Y
YuanRisheng 已提交
2369 2370 2371 2372 2373 2374 2375
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  // Writes 0 * out into dx; x and dout are not read.
  // NOTE(review): because the zero is produced by multiplication, non-finite
  // values in `out` would yield NaN rather than 0 -- confirm this is intended.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(
      Device d, X x UNUSED, Out out, dOut dout UNUSED, dX dx) const {
    const T zero = static_cast<T>(0);
    dx.device(d) = zero * out;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kNoDeps;
  }
};

Y
YuanRisheng 已提交
2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
  // Second-order gradient of sqrt, with y = Out. Both outputs are optional:
  //   dOut  = -dX * ddX / y
  //   ddOut = 0.5 * ddX / y
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* dX,
                  const DenseTensor* ddX,
                  DenseTensor* dOut,
                  DenseTensor* ddOut) const {
    auto* place = dev.eigen_device();
    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
    auto out_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));

    // dOut is produced before ddOut so that ddOut may be computed in place
    // over ddX.
    if (dOut) {
      auto dx_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
      auto dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
      const T minus_one = static_cast<T>(-1);
      dout_vec.device(*place) = dx_vec * ddx_vec * minus_one / out_vec;
    }

    if (ddOut) {
      auto ddout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
      const T half = static_cast<T>(0.5);
      ddout_vec.device(*place) = ddx_vec * half / out_vec;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
  // Second-order gradient of rsqrt, with y = Out. Both outputs are optional:
  //   dOut  = (3 / y) * dX * ddX
  //   ddOut = -0.5 * ddX * y * y * y
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* Out,
                  const DenseTensor* dX,
                  const DenseTensor* ddX,
                  DenseTensor* dOut,
                  DenseTensor* ddOut) const {
    auto* place = dev.eigen_device();
    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
    auto out_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));

    if (dOut) {
      auto dx_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
      auto dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
      const T three = static_cast<T>(3.0);
      dout_vec.device(*place) = (three / out_vec) * dx_vec * ddx_vec;
    }

    if (ddOut) {
      auto ddout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
      const T minus_half = static_cast<T>(-0.5);
      ddout_vec.device(*place) =
          ddx_vec * minus_half * out_vec * out_vec * out_vec;
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
  // CELU scale parameter, bound through GetAttrs().
  float alpha;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // out = x                               if x >= 0
  // out = alpha * (exp(x / alpha) - 1)    if x <  0
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T zero = static_cast<T>(0);
    const T one = static_cast<T>(1);
    auto negative_branch =
        static_cast<T>(alpha) * ((x / static_cast<T>(alpha)).exp() - one);
    out.device(d) = (x < zero).select(negative_branch, x);
  }
};

template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
  // CELU scale parameter, bound through GetAttrs().
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // CELU backward; only the forward input is read.
  //   x >  0 : dx = dout
  //   x <= 0 : dx = dout * exp(x / alpha)
  // The rule is the same for either sign of alpha.
  //
  // The branch is picked with select() instead of summing terms multiplied
  // by 0/1 masks: with masks, exp(x / alpha) is still evaluated for x > 0,
  // and once it overflows to +inf the masked term becomes inf * 0 = NaN and
  // poisons the whole sum.
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    // (x > 0) is false for NaN, so NaN inputs still flow through exp().
    dx.device(d) =
        (x > static_cast<T>(0))
            .select(dout, dout * (x / static_cast<T>(alpha)).exp());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
  // CELU scale parameter, bound through GetAttrs().
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // Second-order CELU gradient.
  //   dX    = ddX * dOut / alpha * exp(X / alpha)   if X <= 0, else 0
  //   ddOut = ddX * exp(X / alpha)                  if X <= 0, else ddX
  // dX and ddOut are optional outputs; each is computed only when non-null.
  //
  // The branches are chosen with select() rather than by multiplying with a
  // 0/1 mask: for X > 0, exp(X / alpha) can overflow to +inf and inf * 0
  // would yield NaN instead of the intended 0 / ddX.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
    const T zero = static_cast<T>(0);

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
      // (x > 0) is false for NaN, so NaN inputs still propagate through the
      // exp() branch exactly as before.
      dx.device(*d) =
          (x > zero).select(dx.constant(zero),
                            ddx * dout / static_cast<T>(alpha) *
                                (x / static_cast<T>(alpha)).exp());
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
      ddout.device(*d) =
          (x > zero).select(ddx, ddx * (x / static_cast<T>(alpha)).exp());
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Double-backward pass of square(x) = x * x:
//   dX    = 2 * dOut * ddX
//   ddOut = 2 * x * ddX
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
  // dX and ddOut are optional outputs: a null pointer skips that result.
  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* eigen_dev = dev.eigen_device();
    const T two = static_cast<T>(2);

    auto ddx_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
    auto x_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));

    // dX must be produced before ddOut so that ddOut is allowed to share
    // storage with ddX (in-place) without corrupting the dX computation.
    if (dX) {
      auto dx_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
      auto dout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
      dx_vec.device(*eigen_dev) = ddx_vec * two * dout_vec;
    }

    if (ddOut) {
      auto ddout_vec = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
      ddout_vec.device(*eigen_dev) = ddx_vec * two * x_vec;
    }
  }

  // Only the forward input X is needed.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

2575
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616 2617 2618 2619 2620 2621 2622 2623 2624 2625 2626 2627

// Device functor for logit(x) = ln(x / (1 - x)).
// The input is first clamped to [eps, 1 - eps]. When eps is zero, inputs
// outside [0, 1] produce NaN instead.
template <typename T>
struct CudaLogitFunctor : public BaseActivationFunctor<T> {
  // Type used for the intermediate arithmetic.
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  MT zero = static_cast<MT>(0.0f);
  MT one = static_cast<MT>(1.0f);
  // Written by the caller through the pointer exposed by GetAttrs().
  float eps;

  // Exposes `eps` under the attribute name "eps".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"eps", &eps}};
  }

  // logit(x) = ln(x/(1-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MT x = static_cast<MT>(arg_x);
    const MT e = static_cast<MT>(eps);
    MT y = min(x, (one - e));
    y = max(y, e);

    if (eps == 0.0f) {
      // No clamping margin was requested: reject out-of-domain inputs.
      // The NaN is created directly in MT so that both arms of the
      // conditional have the same type and no round trip through T occurs.
      y = (x < zero || x > one) ? static_cast<MT>(NAN) : log(y / (one - y));
    } else {
      y = log(y / (one - y));
    }
    return static_cast<T>(y);
  }
};

// Device functor for the gradient of logit, with x the forward input:
//   dx = 0                       if x < eps or x > 1 - eps
//   dx = dout / (x * (1 - x))    otherwise
template <typename T>
struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
  // Type used for the intermediate arithmetic.
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  // Written by the caller through the pointer exposed by GetAttrs().
  float eps;
  MT zero = static_cast<MT>(0.0f);
  MT one = static_cast<MT>(1.0f);

  // Exposes `eps` under the attribute name "eps".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"eps", &eps}};
  }

  // logit(x)' = 1/(x*(1-x))
  __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
    const MT x = static_cast<MT>(arg_x);
    const MT e = static_cast<MT>(eps);
    const bool outside = x < e || x > one - e;
    const MT dx = outside ? zero : (static_cast<MT>(dout) / (x * (one - x)));
    return static_cast<T>(dx);
  }

  // The derivative is evaluated at the forward input x, so the dependency is
  // on X rather than on the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

2628 2629 2630 2631 2632 2633 2634 2635 2636 2637 2638 2639 2640 2641 2642 2643 2644 2645 2646 2647 2648 2649 2650 2651 2652 2653 2654 2655 2656 2657 2658 2659 2660 2661 2662 2663 2664 2665 2666 2667 2668 2669 2670 2671 2672 2673 2674 2675 2676 2677
// Device functor for relu(x) = max(x, 0).
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // Returns x when it compares greater than zero, otherwise zero.
  __device__ __forceinline__ T operator()(const T x) const {
    if (x > zero) {
      return x;
    }
    return zero;
  }
};

// Device functor for the relu gradient: dx = dout * (out > 0).
template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // Passes dout through where the forward output is positive, else zero.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    if (out > zero) {
      return dout;
    }
    return zero;
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for cos(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = cos(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType angle = static_cast<MPType>(arg_x);
    const auto result = cos(angle);
    return static_cast<T>(result);
  }
};

// Device functor for the cos gradient; the math runs in MPType.
template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = -dout * sin(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType angle = static_cast<MPType>(arg_x);
    const auto grad = -upstream * sin(angle);
    return static_cast<T>(grad);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

2678 2679
// Device functor for exp(x), evaluated with the single-precision expf.
// Integral T yields a float result; any other T yields T.
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
  // Result type: float for integral inputs, otherwise T itself.
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // out = expf(float(x))
  __device__ __forceinline__ U operator()(const T x) const {
    const float as_float = static_cast<float>(x);
    return static_cast<U>(expf(as_float));
  }
};
2687

2688 2689
// Double-precision specialization: uses exp instead of expf.
template <>
struct CudaExpFunctor<double> : public BaseActivationFunctor<double> {
  // out = exp(x)
  __device__ __forceinline__ double operator()(const double x) const {
    const double result = exp(x);
    return result;
  }
};

2696 2697 2698 2699 2700 2701 2702
// Device functor for SELU:
//   out = scale * x                          if x >  0
//   out = scale * (alpha * exp(x) - alpha)   if x <= 0
template <typename T>
struct CudaSeluFunctor : public BaseActivationFunctor<T> {
  // Exposes `scale` and `alpha` (in that order) under their attribute names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale", &scale}, {"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T x) const {
    // Working type: T when T is wider than float, float otherwise.
    using MT =
        typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type;
    MT value = static_cast<MT>(x);
    const bool non_positive = x <= zero;
    if (non_positive) {
      value = alpha * expf(value) - alpha;
    }
    value *= scale;
    return static_cast<T>(value);
  }

 private:
  float scale;
  float alpha;
  T zero = static_cast<T>(0.0f);
};

// Double-precision SELU specialization: the attributes are widened to double
// and exp is used instead of expf.
template <>
struct CudaSeluFunctor<double> : public BaseActivationFunctor<double> {
  // Exposes `scale` and `alpha` (in that order) under their attribute names.
  typename BaseActivationFunctor<double>::AttrPair GetAttrs() {
    return {{"scale", &scale}, {"alpha", &alpha}};
  }

  __device__ __forceinline__ double operator()(const double x) const {
    const double alpha_d = static_cast<double>(alpha);
    const double scale_d = static_cast<double>(scale);
    double value = x;
    if (value <= zero) {
      // Non-positive inputs follow alpha * exp(x) - alpha.
      value = alpha_d * exp(value) - alpha_d;
    }
    value *= scale_d;
    return value;
  }

 private:
  float scale;
  float alpha;
  double zero = static_cast<double>(0.0f);
};

2742 2743 2744 2745 2746 2747 2748 2749 2750 2751 2752 2753 2754 2755 2756 2757 2758 2759 2760 2761 2762 2763 2764 2765 2766 2767 2768 2769 2770 2771 2772 2773
// Device functor for square(x) = x * x, computed directly in T.
template <typename T>
struct CudaSquareFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T x) const {
    const T squared = x * x;
    return squared;
  }
};

// Device functor for the square gradient, computed directly in T.
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
  T two = static_cast<T>(2.0f);

  // dx = dout * 2 * x
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T grad = dout * two * x;
    return grad;
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for the exp gradient, computed directly in T.
template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out, where out is the forward result exp(x).
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T grad = dout * out;
    return grad;
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for reciprocal(x) = 1 / x; the division runs in MPType.
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // out = 1 / x
  __device__ __forceinline__ T operator()(const T x) const {
    const MPType denom = static_cast<MPType>(x);
    return static_cast<T>(one / denom);
  }
};

// Device functor for the reciprocal gradient, computed directly in T.
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  // dx = -dout * out^2, where out is the forward result 1 / x.
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T grad = -dout * out * out;
    return grad;
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for expm1(x), evaluated with the single-precision expm1f.
// Integral T yields a float result; any other T yields T.
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  // Result type: float for integral inputs, otherwise T itself.
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // out = expm1f(float(x))
  __device__ __forceinline__ U operator()(const T x) const {
    const float as_float = static_cast<float>(x);
    return static_cast<U>(::expm1f(as_float));
  }
};
2803

2804 2805
// Double-precision specialization: uses expm1 instead of expm1f.
template <>
struct CudaExpm1Functor<double> : public BaseActivationFunctor<double> {
  // out = expm1(x)
  __device__ __forceinline__ double operator()(const double x) const {
    const double result = ::expm1(x);
    return result;
  }
};

// Device functor for the expm1 gradient, computed directly in T.
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
  // dx = dout * out + dout, i.e. dout * (out + 1) with out = expm1(x).
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T grad = dout * out + dout;
    return grad;
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

2824 2825 2826 2827 2828 2829 2830 2831 2832 2833 2834 2835 2836 2837 2838 2839 2840 2841 2842 2843 2844 2845 2846 2847 2848 2849 2850 2851 2852 2853 2854 2855 2856 2857 2858 2859 2860 2861 2862 2863 2864 2865 2866 2867 2868 2869 2870 2871 2872 2873 2874 2875 2876 2877 2878 2879 2880 2881 2882 2883 2884 2885 2886 2887 2888 2889 2890 2891 2892 2893 2894 2895 2896 2897 2898 2899 2900 2901 2902 2903 2904 2905 2906 2907 2908 2909 2910 2911 2912 2913 2914 2915 2916 2917 2918 2919 2920 2921 2922 2923 2924 2925 2926 2927 2928 2929 2930 2931 2932 2933 2934 2935 2936 2937 2938 2939 2940 2941 2942 2943 2944 2945 2946 2947 2948 2949 2950 2951 2952 2953 2954 2955 2956 2957 2958 2959 2960 2961 2962 2963 2964 2965 2966 2967 2968 2969 2970 2971 2972 2973 2974 2975 2976 2977 2978 2979 2980 2981 2982 2983 2984 2985 2986 2987 2988 2989 2990 2991 2992 2993 2994 2995 2996 2997 2998 2999 3000 3001 3002 3003 3004 3005 3006 3007 3008 3009 3010 3011 3012 3013 3014 3015 3016 3017 3018 3019 3020 3021 3022 3023 3024 3025 3026 3027 3028 3029 3030 3031 3032 3033 3034 3035 3036 3037 3038 3039 3040 3041 3042 3043 3044 3045
// Device functor for sin(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = sin(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType angle = static_cast<MPType>(arg_x);
    const auto result = sin(angle);
    return static_cast<T>(result);
  }
};

// Device functor for the sin gradient; the math runs in MPType.
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * cos(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType angle = static_cast<MPType>(arg_x);
    const auto grad = upstream * cos(angle);
    return static_cast<T>(grad);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for tan(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = tan(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType angle = static_cast<MPType>(arg_x);
    const auto result = tan(angle);
    return static_cast<T>(result);
  }
};

// Device functor for the tan gradient; the math runs in MPType.
template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout / cos(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType angle = static_cast<MPType>(arg_x);
    const auto cos_x = cos(angle);
    return static_cast<T>(upstream / (cos_x * cos_x));
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for asin(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = asin(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = asin(value);
    return static_cast<T>(result);
  }
};

// Device functor for the asin gradient; the math runs in MPType.
template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout / sqrt(1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const auto denom = sqrt(one - value * value);
    return static_cast<T>(upstream / denom);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for acos(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = acos(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = acos(value);
    return static_cast<T>(result);
  }
};

// Device functor for the acos gradient; the math runs in MPType.
template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = -dout / sqrt(1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const auto denom = sqrt(one - value * value);
    return static_cast<T>(-upstream / denom);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for cosh(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = cosh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = cosh(value);
    return static_cast<T>(result);
  }
};

// Device functor for the cosh gradient; the math runs in MPType.
template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * sinh(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const auto grad = upstream * sinh(value);
    return static_cast<T>(grad);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for sinh(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = sinh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = sinh(value);
    return static_cast<T>(result);
  }
};

// Device functor for the sinh gradient; the math runs in MPType.
template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * cosh(x)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const auto grad = upstream * cosh(value);
    return static_cast<T>(grad);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for acosh(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = acosh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = acosh(value);
    return static_cast<T>(result);
  }
};

// Device functor for the acosh gradient; the math runs in MPType.
template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / sqrt(x^2 - 1)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const auto denom = sqrt(value * value - one);
    return static_cast<T>(upstream * one / denom);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for asinh(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = asinh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = asinh(value);
    return static_cast<T>(result);
  }
};

// Device functor for the asinh gradient; the math runs in MPType.
template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / sqrt(x^2 + 1)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const auto denom = sqrt(value * value + one);
    return static_cast<T>(upstream * one / denom);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for atanh(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = atanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = atanh(value);
    return static_cast<T>(result);
  }
};

3046 3047 3048 3049 3050 3051 3052 3053 3054 3055 3056 3057 3058 3059 3060 3061 3062 3063 3064 3065 3066 3067 3068 3069 3070 3071 3072 3073 3074 3075 3076 3077 3078 3079 3080 3081 3082 3083 3084 3085 3086 3087 3088 3089 3090 3091 3092 3093 3094 3095 3096 3097 3098 3099 3100 3101 3102 3103 3104 3105 3106 3107 3108 3109 3110 3111 3112 3113 3114 3115 3116 3117 3118 3119 3120 3121 3122 3123 3124 3125 3126 3127 3128 3129 3130 3131 3132 3133 3134 3135
// Device functor for scaled tanh: stanh(x) = scale_b * tanh(scale_a * x).
// The math runs in MPType and is cast back to T.
template <typename T>
struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  // Both are written by the caller through the pointers from GetAttrs().
  float scale_a;
  float scale_b;

  // Exposes `scale_a` and `scale_b` (in that order) under their names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // out = b * tanh(a * x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType inner = static_cast<MPType>(scale_a);
    const MPType outer = static_cast<MPType>(scale_b);
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = outer * tanh(inner * value);
    return static_cast<T>(result);
  }
};

// Device functor for the scaled-tanh gradient; the math runs in MPType.
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Both are written by the caller through the pointers from GetAttrs().
  float scale_a;
  float scale_b;

  // Exposes `scale_a` and `scale_b` (in that order) under their names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
  }

  // dx = dout * a * b * (1 - tanh(a * x)^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType inner = static_cast<MPType>(scale_a);
    const MPType outer = static_cast<MPType>(scale_b);
    const MPType th = tanh(inner * value);
    return static_cast<T>(upstream * inner * outer * (one - th * th));
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for softplus; the math runs in MPType:
//   out = x                               if beta * x > threshold
//   out = log(1 + exp(beta * x)) / beta   otherwise
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Both are written by the caller through the pointers from GetAttrs().
  float beta;
  float threshold;

  // Exposes `beta` and `threshold` (in that order) under their names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType beta_mp = static_cast<MPType>(beta);
    const MPType limit = static_cast<MPType>(threshold);
    const MPType scaled = value * beta_mp;
    // Above the threshold the input is returned unchanged.
    if (scaled > limit) {
      return static_cast<T>(value);
    }
    return static_cast<T>(log(one + exp(scaled)) / beta_mp);
  }
};

// Device functor for the softplus gradient; the math runs in MPType:
//   dx = dout                            if beta * x > threshold
//   dx = dout / (1 + exp(-beta * x))     otherwise
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Both are written by the caller through the pointers from GetAttrs().
  float beta;
  float threshold;

  // Exposes `beta` and `threshold` (in that order) under their names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType t = static_cast<MPType>(threshold);
    // Use the already-converted `b`; it was previously computed but unused.
    MPType x_beta = x * b;
    // Above the threshold the incoming gradient is forwarded untouched, so it
    // is returned in its original type without a round trip through MPType.
    return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

3136 3137 3138 3139 3140 3141 3142 3143 3144 3145 3146 3147 3148 3149 3150
// Device functor for the atanh gradient; the math runs in MPType.
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * 1 / (1 - x^2)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType denom = one - value * value;
    return static_cast<T>(upstream * one / denom);
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

3151 3152 3153 3154 3155 3156 3157 3158 3159 3160 3161 3162 3163 3164 3165 3166 3167 3168 3169 3170 3171 3172 3173 3174 3175 3176 3177 3178 3179 3180 3181 3182 3183 3184 3185 3186 3187 3188
// Device functor for sqrt(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = sqrt(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = sqrt(value);
    return static_cast<T>(result);
  }
};

// Device functor for the sqrt gradient, computed directly in T.
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
  T one_half = static_cast<T>(0.5f);

  // dx = 0.5 * dout / out, where out is the forward result sqrt(x).
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T grad = one_half * dout / out;
    return grad;
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for rsqrt(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = rsqrt(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = rsqrt(value);
    return static_cast<T>(result);
  }
};

// Device functor for the rsqrt gradient; the math runs in MPType.
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType minus_one_half = static_cast<MPType>(-0.5f);

  // dx = -0.5 * dout * out^3, where out is the forward result rsqrt(x).
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType fwd = static_cast<MPType>(arg_out);
    const MPType grad = minus_one_half * upstream * fwd * fwd * fwd;
    return static_cast<T>(grad);
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

3205 3206 3207 3208 3209 3210 3211 3212 3213 3214 3215 3216 3217 3218 3219 3220 3221 3222 3223 3224 3225 3226 3227
// Device functor for atan(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = atan(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = atan(value);
    return static_cast<T>(result);
  }
};

// Device functor for the atan gradient, computed directly in T.
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x^2)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T denom = one + x * x;
    return dout / denom;
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

3228 3229 3230 3231 3232 3233 3234 3235 3236 3237 3238 3239 3240 3241 3242 3243 3244 3245 3246 3247 3248 3249 3250 3251 3252 3253
// Device functor for tanh(x); the math runs in MPType and is cast back to T.
template <typename T>
struct CudaTanhFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // out = tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const auto result = tanh(value);
    return static_cast<T>(result);
  }
};

// Device functor for the tanh gradient, computed directly in T.
template <typename T>
struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * (1 - out^2), where out is the forward result tanh(x).
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T local_slope = one - out * out;
    return dout * local_slope;
  }

  // The gradient is expressed in terms of the forward output.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Device functor for hard tanh: clamps x to [t_min, t_max], computed in T.
template <typename T>
struct CudaHardTanhFunctor : public BaseActivationFunctor<T> {
  // Both are written by the caller through the pointers from GetAttrs().
  float t_min;
  float t_max;

  // Exposes `t_min` and `t_max` (in that order) under their names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // hardtanh(x) = min(max(x, t_min), t_max)
  __device__ __forceinline__ T operator()(const T x) const {
    const T lower = static_cast<T>(t_min);
    const T upper = static_cast<T>(t_max);
    T clipped = x;
    // Raise to the lower bound unless x is strictly above it.
    if (!(x > lower)) {
      clipped = lower;
    }
    // Cap at the upper bound unless the value is strictly below it.
    if (!(clipped < upper)) {
      clipped = upper;
    }
    return clipped;
  }
};

3272 3273 3274 3275 3276 3277 3278 3279 3280 3281 3282 3283 3284 3285 3286 3287 3288 3289 3290 3291 3292 3293 3294 3295 3296 3297 3298 3299 3300 3301 3302 3303 3304 3305 3306 3307 3308 3309 3310 3311 3312 3313 3314 3315 3316 3317 3318 3319 3320
// Device functor for mish(x) = x * tanh(softplus(x)); math runs in MPType.
//   softplus(x) = x                  if x > threshold
//               = ln(1 + exp(x))     otherwise
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Written by the caller through the pointer exposed by GetAttrs().
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // arg_x: the input x
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType limit = static_cast<MPType>(threshold);
    MPType softplus;
    if (value > limit) {
      softplus = value;
    } else {
      softplus = log(one + exp(value));
    }
    return static_cast<T>(value * tanh(softplus));
  }
};

// Device functor for the mish gradient; the math runs in MPType.
//   sp  = softplus(x)  (x itself above the threshold, ln(1 + exp(x)) below)
//   gsp = d sp / dx    (1 above the threshold, 1 / (1 + exp(-x)) below)
//   dx  = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * gsp)
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  // Written by the caller through the pointer exposed by GetAttrs().
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // arg_dout: the incoming gradient
  // arg_x:    the forward input x
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const bool above = value > static_cast<MPType>(threshold);

    MPType softplus;
    MPType softplus_grad;
    if (above) {
      softplus = value;
      softplus_grad = one;
    } else {
      softplus = log(one + exp(value));
      softplus_grad = one / (one + exp(-value));
    }

    const MPType th = tanh(softplus);
    return static_cast<T>(upstream *
                          (th + value * (one - th * th) * softplus_grad));
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

3321
// Device functor for the hard-tanh gradient, computed directly in T.
template <typename T>
struct CudaHardTanhGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  // Both are written by the caller through the pointers from GetAttrs().
  float t_min;
  float t_max;

  // Exposes `t_min` and `t_max` (in that order) under their names.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"t_min", &t_min}, {"t_max", &t_max}};
  }

  // dx = (x > t_min && x < t_max) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T lower = static_cast<T>(t_min);
    const T upper = static_cast<T>(t_max);
    // The gradient only flows strictly inside the clamp interval.
    if (x > lower && x < upper) {
      return dout;
    }
    return zero;
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Device functor for thresholded relu, computed directly in T.
template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  // Written by the caller through the pointer exposed by GetAttrs().
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // thresholded_relu(x) = x > threshold ? x : 0
  __device__ __forceinline__ T operator()(const T x) const {
    const T limit = static_cast<T>(threshold);
    if (x > limit) {
      return x;
    }
    return zero;
  }
};

// Device functor for the thresholded-relu gradient, computed directly in T.
template <typename T>
struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  // Written by the caller through the pointer exposed by GetAttrs().
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = x > threshold ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T limit = static_cast<T>(threshold);
    if (x > limit) {
      return dout;
    }
    return zero;
  }

  // The gradient is expressed in terms of the forward input.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

3373 3374 3375 3376 3377 3378 3379 3380 3381 3382 3383 3384 3385 3386 3387 3388 3389 3390 3391 3392
// Device functor for relu6-style clamping, computed directly in T.
// The upper bound comes from the `threshold` attribute.
template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  // Written by the caller through the pointer exposed by GetAttrs().
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // relu6(x) = min(max(0, x), threshold)
  __device__ __forceinline__ T operator()(const T x) const {
    const T upper = static_cast<T>(threshold);
    if (x <= zero) {
      return zero;
    }
    // Strictly below the bound the input passes through, otherwise the bound.
    if (x < upper) {
      return x;
    }
    return upper;
  }
};

// Backward functor for relu6.
// The incoming gradient is forwarded where the forward output lies strictly
// between 0 and the upper clamp; elsewhere the result is zero. The upper clamp
// is fixed at 6 here and is not configurable through attributes.
template <typename T>
struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);

  // This functor has no attributes, so the list is empty.
  // NOTE: `return {{}};` would build a ONE-element vector whose entry is a
  // {nullptr, nullptr} pair; a caller iterating the list and dereferencing
  // the name or the pointer would then hit a null pointer.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() { return {}; }

  // dx = (out > 0 && out < 6) ? dout : 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const float threshold = 6.0f;
    const T t = static_cast<T>(threshold);
    return (out > zero && out < t) ? dout : zero;
  }

  // Only the forward output Out is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
3406 3407 3408 3409 3410 3411 3412 3413 3414 3415 3416 3417 3418 3419 3420 3421 3422 3423 3424 3425 3426 3427 3428 3429 3430 3431 3432 3433 3434 3435 3436
// Forward functor for leaky_relu.
// Positive inputs pass through unchanged; all other inputs are scaled by
// `alpha`.
template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;  // scale for non-positive inputs, published as "alpha"

  // Returns the (attribute name, storage pointer) pair for the scale.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // leakyrelu(x) = x if x > 0, else alpha * x
  __device__ __forceinline__ T operator()(const T x) const {
    if (x > zero) {
      return x;
    }
    return static_cast<T>(alpha) * x;
  }
};

// Backward functor for leaky_relu.
// The incoming gradient is passed unchanged for positive forward inputs and
// scaled by `alpha` otherwise.
template <typename T>
struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float alpha;  // scale for non-positive inputs, published as "alpha"

  // Returns the (attribute name, storage pointer) pair for the scale.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout if x > 0, else alpha * dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    if (x > zero) {
      return dout;
    }
    return static_cast<T>(alpha) * dout;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};
Y
YuanRisheng 已提交
3437 3438 3439 3440 3441 3442 3443 3444 3445 3446 3447 3448 3449 3450 3451 3452 3453 3454 3455 3456 3457 3458 3459 3460 3461 3462 3463 3464 3465 3466 3467 3468 3469 3470 3471 3472 3473 3474 3475 3476 3477 3478 3479 3480 3481 3482 3483 3484 3485 3486 3487 3488 3489 3490 3491 3492 3493 3494 3495 3496 3497 3498 3499 3500 3501 3502 3503 3504 3505 3506 3507 3508 3509 3510 3511 3512 3513 3514 3515 3516 3517 3518 3519 3520 3521 3522 3523 3524 3525 3526 3527 3528 3529 3530 3531 3532 3533 3534 3535 3536 3537 3538 3539 3540 3541 3542 3543 3544 3545 3546 3547 3548 3549 3550 3551 3552 3553 3554 3555 3556 3557 3558 3559 3560 3561 3562 3563 3564 3565 3566 3567 3568 3569 3570 3571 3572 3573 3574 3575 3576 3577 3578 3579 3580 3581 3582 3583 3584 3585 3586 3587 3588 3589 3590 3591 3592 3593 3594 3595 3596 3597 3598 3599 3600 3601 3602 3603 3604 3605 3606 3607 3608 3609 3610 3611 3612 3613 3614 3615 3616 3617 3618 3619 3620 3621 3622 3623 3624 3625 3626 3627 3628 3629 3630 3631 3632 3633 3634 3635 3636 3637 3638 3639

// Forward functor for softshrink.
// The result is built without branching: two 0/1 masks select the
// "above lambda" and "below -lambda" pieces, and their sum is returned.
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;  // shrink distance, published by GetAttrs() under "lambda"

  // Returns the (attribute name, storage pointer) pair for the distance.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // softshrink(x) = x - lambda   if x >  lambda
  //                 x + lambda   if x < -lambda
  //                 0            otherwise
  __device__ __forceinline__ T operator()(const T x) const {
    const T bound = static_cast<T>(lambda);
    const T above_mask = static_cast<T>(x > bound);
    const T below_mask = static_cast<T>(x < -bound);
    const T above_part = above_mask * (x - bound);
    const T below_part = below_mask * (x + bound);
    return above_part + below_part;
  }
};

// Backward functor for softshrink.
// The result is zero where the forward input lies in the closed interval
// [-lambda, lambda]; otherwise the incoming gradient is forwarded.
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;  // shrink distance, published by GetAttrs() under "lambda"

  // Returns the (attribute name, storage pointer) pair for the distance.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  // dx = 0 if -lambda <= x <= lambda, else dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T bound = static_cast<T>(lambda);
    if (x >= -bound && x <= bound) {
      return zero;
    }
    return dout;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for tanhshrink.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // tanhshrink(x) = x - tanh(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType shrunk = value - tanh(value);
    return static_cast<T>(shrunk);
  }
};

// Backward functor for tanhshrink.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // dx = dout * tanh(x)^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType th = tanh(value);
    return static_cast<T>(grad * th * th);
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for hardshrink.
// Inputs strictly inside (-threshold, threshold) become zero; all other
// inputs pass through unchanged.
template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;  // half-width of the zeroed band, published as "threshold"

  // Returns the (attribute name, storage pointer) pair for the half-width.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // hardshrink(x) = 0 if -threshold < x < threshold, else x
  __device__ __forceinline__ T operator()(const T x) const {
    const T bound = static_cast<T>(threshold);
    if (x > -bound && x < bound) {
      return zero;
    }
    return x;
  }
};

// Backward functor for hardshrink.
// The result is zero where the forward input lies strictly inside
// (-threshold, threshold); otherwise the incoming gradient is forwarded.
template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;  // half-width of the zeroed band, published as "threshold"

  // Returns the (attribute name, storage pointer) pair for the half-width.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  // dx = 0 if -threshold < x < threshold, else dout
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T bound = static_cast<T>(threshold);
    if (x > -bound && x < bound) {
      return zero;
    }
    return dout;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for ELU.
// The arithmetic is done in CT (as given by MPTypeTrait<T>) and the result is
// converted back to T.
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;  // scale of the exponential branch, published as "alpha"

  // Returns the (attribute name, storage pointer) pair for the scale.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // elu(x) = x                     if x > 0
  //          alpha * (e^x - 1)     if x <= 0
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT value = static_cast<CT>(arg_x);
    const CT exp_branch = static_cast<CT>(alpha) * (exp(value) - one);
    if (value > zero) {
      return static_cast<T>(value);
    }
    return static_cast<T>(exp_branch);
  }
};

// Backward functor for ELU, expressed in terms of the forward output
// (the "alpha >= 0" case according to the original comment).
// The result is built without branching from two 0/1 masks on the sign of
// `out`. The arithmetic is done in MPType and converted back to T.
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;  // scale of the exponential branch, published as "alpha"

  // Returns the (attribute name, storage pointer) pair for the scale.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 1: alpha >= 0
  // dx = dout                   if out > 0
  // dx = dout * (out + alpha)   if out <= 0
  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType y = static_cast<MPType>(arg_out);
    const MPType scale = static_cast<MPType>(alpha);
    const MPType positive_mask = static_cast<MPType>(y > zero);
    const MPType non_positive_mask = static_cast<MPType>(y <= zero);
    const MPType factor = positive_mask + non_positive_mask * (y + scale);
    return static_cast<T>(grad * factor);
  }

  // Only the forward output Out is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Backward functor for ELU that selects the branch by the sign of the
// forward input x (the "alpha < 0" case according to the original comment).
// The result is built without branching from two 0/1 masks on the sign of
// `x`. The arithmetic is done in MPType and converted back to T.
template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;  // scale of the exponential branch, published as "alpha"

  // Returns the (attribute name, storage pointer) pair for the scale.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // case 2: alpha < 0
  // dx = dout                   if x > 0
  // dx = dout * (out + alpha)   if x <= 0
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType y = static_cast<MPType>(arg_out);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType scale = static_cast<MPType>(alpha);
    const MPType positive_mask = static_cast<MPType>(value > zero);
    const MPType non_positive_mask = static_cast<MPType>(value <= zero);
    const MPType factor = positive_mask + non_positive_mask * (y + scale);
    return static_cast<T>(grad * factor);
  }

  // Reports a dependency on the forward input X.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for SiLU.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // silu(x) = x / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType denom = one + exp(-value);
    return static_cast<T>(value / denom);
  }
};

// Backward functor for SiLU.
// With s = 1 / (1 + exp(-x)), the code evaluates dout * s * (1 + x * (1 - s)).
// The arithmetic is done in MPType and converted back to T.
template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // dx = dout * (1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    const MPType sig = one / (one + exp(-value));
    const MPType factor = sig * (one + value * (one - sig));
    return static_cast<T>(grad * factor);
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

3640 3641 3642 3643 3644 3645 3646 3647 3648 3649 3650 3651 3652 3653 3654 3655 3656 3657 3658 3659 3660 3661 3662 3663 3664
// Forward functor for softsign.
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // softsign(x) = x / (1 + |x|)
  __device__ __forceinline__ T operator()(const T x) const {
    // |x| is spelled out by hand: calling abs directly causes a namespace
    // conflict here.
    const T magnitude = (x > -x) ? x : -x;
    return x / (one + magnitude);
  }
};

// Backward functor for softsign.
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + |x|)^2
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    // |x| is spelled out by hand: calling abs directly causes a namespace
    // conflict here.
    const T magnitude = (x > -x) ? x : -x;
    const T denom = one + magnitude;
    return dout / (denom * denom);
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

Y
YuanRisheng 已提交
3665 3666 3667 3668 3669 3670 3671 3672 3673 3674 3675 3676 3677 3678 3679 3680 3681 3682 3683 3684 3685 3686 3687 3688 3689 3690 3691 3692 3693 3694 3695 3696 3697 3698 3699 3700 3701 3702 3703 3704 3705 3706 3707 3708 3709 3710 3711 3712 3713 3714 3715 3716 3717 3718 3719 3720 3721 3722 3723 3724 3725 3726 3727 3728 3729 3730 3731 3732 3733 3734 3735 3736 3737 3738 3739 3740 3741 3742 3743 3744 3745 3746 3747 3748 3749 3750 3751 3752 3753 3754 3755 3756 3757 3758 3759 3760 3761 3762 3763 3764 3765 3766 3767 3768 3769 3770
// Forward functor for sigmoid.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // sigmoid(x) = 1 / (1 + exp(-x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType denom = one + exp(-value);
    return static_cast<T>(one / denom);
  }
};

// Backward functor for sigmoid, expressed in terms of the forward output.
// The arithmetic is carried out directly in T.
template <typename T>
struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout * out * (1 - out)
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    const T complement = one - out;
    return dout * out * complement;
  }

  // Only the forward output Out is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

// Forward functor for logsigmoid.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // logsigmoid(x) = log(1 / (1 + exp(-x)))
  // Rewritten for numerical stability with m = max(-x, 0):
  // logsigmoid(x) = -(m + log(exp(-m) + exp(-x - m)))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    // m = max(-x, 0)
    const MPType shift = value > zero ? zero : -value;
    const MPType exp_sum = exp(-shift) + exp(-value - shift);
    return static_cast<T>(-shift - log(exp_sum));
  }
};

// Backward functor for logsigmoid.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);

  // dx = dout * exp(-x) / (1 + exp(-x))
  // Rewritten for numerical stability with m = max(-x, 0):
  // dx = dout * exp(-x - m) / (exp(-m) + exp(-x - m))
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType value = static_cast<MPType>(arg_x);
    // m = max(-x, 0)
    const MPType shift = value > zero ? zero : -value;
    const MPType numer = exp(-value - shift);
    const MPType ratio = numer / (exp(-shift) + numer);
    return static_cast<T>(grad * ratio);
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for hard_sigmoid.
// Evaluates x * slope + offset and clamps the result to [0, 1]. The
// arithmetic is carried out directly in T.
template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;   // linear coefficient, published by GetAttrs() under "slope"
  float offset;  // additive term, published by GetAttrs() under "offset"

  // Returns (attribute name, storage pointer) pairs for both parameters.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // hard_sigmoid(x) = min(max(x * slope + offset, 0), 1)
  __device__ __forceinline__ T operator()(const T x) const {
    const T affine = x * static_cast<T>(slope) + static_cast<T>(offset);
    T lower_clamped = zero;
    if (affine > zero) {
      lower_clamped = affine;
    }
    if (lower_clamped < one) {
      return lower_clamped;
    }
    return one;
  }
};

// Backward functor for hard_sigmoid, expressed in terms of the forward
// output. The gradient is dout * slope where the output lies strictly
// between 0 and 1, and zero elsewhere.
template <typename T>
struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  T one = static_cast<T>(1.0f);
  float slope;   // linear coefficient, published by GetAttrs() under "slope"
  float offset;  // published under "offset"; not read by operator()

  // Returns (attribute name, storage pointer) pairs for both parameters.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"slope", &slope}, {"offset", &offset}};
  }

  // dx = dout * slope if 0 < out < 1, else 0
  __device__ __forceinline__ T operator()(const T dout, const T out) const {
    if (out > zero && out < one) {
      return dout * static_cast<T>(slope);
    }
    return zero;
  }

  // Only the forward output Out is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepOut; }
};

3771 3772 3773 3774 3775 3776 3777 3778 3779 3780 3781 3782 3783 3784 3785 3786 3787 3788 3789 3790
// Natural logarithm helper for device code.
// The generic version must not be instantiated with double (enforced by the
// static_assert); double is served by the explicit specialization below.
// For integral T the declared return type is float, otherwise it is T.
template <typename T>
__device__ __forceinline__
    std::conditional_t<std::is_integral<T>::value, float, T>
    log_local(T x) {
  static_assert(!std::is_same<T, double>::value,
                "this template must be used with float or less precise type");

#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log(x);
#else
  // __logf is the fast approximate intrinsic, chosen for peak bandwidth
  return __logf(x);
#endif
}

// double: always use the regular ::log.
template <>
__device__ __forceinline__ double log_local<double>(double x) {
  return ::log(x);
}

3791 3792 3793
// Forward functor for log.
// The input is converted to MPType, passed to log_local, and the result is
// converted to U (float when T is integral, otherwise T).
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // log(x) = log(x)
  __device__ __forceinline__ U operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<U>(log_local(value));
  }
};

// Backward functor for log. The division is carried out directly in T.
template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
  // dx = dout / x
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T quotient = dout / x;
    return quotient;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for log1p.
// The input is converted to MPType, 1 is added, the sum is passed to
// log_local, and the result is converted to U (float when T is integral,
// otherwise T).
// NOTE(review): this evaluates log(1 + x) rather than a dedicated log1p
// routine, which presumably loses precision for |x| close to 0 — confirm
// whether that trade-off is intended.
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // log1p(x) = log(1 + x)
  __device__ __forceinline__ U operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<U>(log_local(one + value));
  }
};

// Backward functor for log1p. The arithmetic is carried out directly in T.
template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
  T one = static_cast<T>(1.0f);

  // dx = dout / (1 + x)
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T denom = one + x;
    return dout / denom;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

3838 3839 3840 3841 3842 3843 3844 3845 3846 3847 3848 3849 3850 3851 3852 3853 3854 3855 3856 3857
// Base-2 logarithm helper for device code.
// The generic version must not be instantiated with double (enforced by the
// static_assert); double is served by the explicit specialization below.
// For integral T the declared return type is float, otherwise it is T.
template <typename T>
__device__ __forceinline__
    std::conditional_t<std::is_integral<T>::value, float, T>
    log2_local(T x) {
  static_assert(!std::is_same<T, double>::value,
                "this template must be used with float or less precise type");

#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log2(x);
#else
  // __log2f is the fast approximate intrinsic, chosen for peak bandwidth
  return __log2f(x);
#endif
}

// double: always use the regular ::log2.
template <>
__device__ __forceinline__ double log2_local<double>(double x) {
  return ::log2(x);
}

3858 3859 3860
// Forward functor for log2.
// The input is converted to MPType, passed to log2_local, and the result is
// converted to U (float when T is integral, otherwise T).
template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // log2(x) = log2(x)
  __device__ __forceinline__ U operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<U>(log2_local(value));
  }
};

// Backward functor for log2.
// log(2) is evaluated in MPType once per functor instance and stored as T;
// the per-element arithmetic is carried out in T.
template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  T log_two = static_cast<T>(log(static_cast<MPType>(2.0f)));

  // dx = dout / (x * log(2))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T denom = x * log_two;
    return dout / denom;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

3883 3884 3885 3886 3887 3888 3889 3890 3891 3892 3893 3894 3895 3896 3897 3898 3899 3900 3901 3902
// Base-10 logarithm helper for device code.
// The generic version must not be instantiated with double (enforced by the
// static_assert); double is served by the explicit specialization below.
// For integral T the declared return type is float, otherwise it is T.
template <typename T>
__device__ __forceinline__
    std::conditional_t<std::is_integral<T>::value, float, T>
    log10_local(T x) {
  static_assert(!std::is_same<T, double>::value,
                "this template must be used with float or less precise type");

#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log10(x);
#else
  // __log10f is the fast approximate intrinsic, chosen for peak bandwidth
  return __log10f(x);
#endif
}

// double: always use the regular ::log10.
template <>
__device__ __forceinline__ double log10_local<double>(double x) {
  return ::log10(x);
}

3903 3904 3905
// Forward functor for log10.
// The input is converted to MPType, passed to log10_local, and the result is
// converted to U (float when T is integral, otherwise T).
template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;

  // log10(x) = log10(x)
  __device__ __forceinline__ U operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    return static_cast<U>(log10_local(value));
  }
};

// Backward functor for log10.
// log(10) is evaluated in MPType once per functor instance and stored as T;
// the per-element arithmetic is carried out in T.
template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  T log_ten = static_cast<T>(log(static_cast<MPType>(10.0f)));

  // dx = dout / (x * log(10))
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T denom = x * log_ten;
    return dout / denom;
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

Y
YuanRisheng 已提交
3928 3929 3930 3931
// Forward functor for swish.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);
  float beta = 1.0f;  // defaults to 1; published by GetAttrs() under "beta"

  // Returns the (attribute name, storage pointer) pair for beta.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  // swish(x) = x / (1 + exp(-beta * x))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType scale = static_cast<MPType>(beta);
    const MPType denom = one + exp(-scale * value);
    return static_cast<T>(value / denom);
  }
};

// Backward functor for swish.
// beta is fixed at 1 here and is not configurable through attributes.
// With s = 1 / (1 + exp(-b * x)) and out = x * s, the code evaluates
// dout * (b * out + s * (1 - b * out)). The arithmetic is done in MPType and
// converted back to T.
template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // This functor has no attributes, so the list is empty.
  // NOTE: `return {{}};` would build a ONE-element vector whose entry is a
  // {nullptr, nullptr} pair; a caller iterating the list and dereferencing
  // the name or the pointer would then hit a null pointer.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() { return {}; }

  // dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x)) / (1 + exp(-b * x))^2
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const float beta = 1.0f;
    MPType dout = static_cast<MPType>(arg_dout);
    MPType x = static_cast<MPType>(arg_x);
    MPType b = static_cast<MPType>(beta);
    MPType temp1 = one / (one + exp(-b * x));  // sigmoid(b * x)
    MPType out = x * temp1;                    // forward swish value
    MPType temp2 = b * out;
    MPType temp3 = temp1 * (one - temp2);
    return static_cast<T>(dout * (temp2 + temp3));
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Forward functor for hard_swish.
// Evaluates min(max(x + offset, 0), threshold) * x / scale in MPType and
// converts the result back to T.
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  const MPType zero = static_cast<MPType>(0.0f);
  float threshold;  // upper clamp, published by GetAttrs() under "threshold"
  float scale;      // divisor, published by GetAttrs() under "scale"
  float offset;     // shift added to x, published under "offset"

  // Returns (attribute name, storage pointer) pairs for the three parameters.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // hard_swish(x) = 0, when x <= -offset
  //                 x , when x >= threshold - offset
  //                 x * (x + offset) / scale, otherwise
  // threshold = scale = 6, offset = 3 by default
  __device__ __forceinline__ T operator()(const T x) const {
    const MPType value = static_cast<MPType>(x);
    const MPType shifted = value + static_cast<MPType>(offset);
    const MPType upper = static_cast<MPType>(threshold);
    // max(shifted, 0)
    const MPType lower_clamped = (shifted < zero) ? zero : shifted;
    // min(lower_clamped, threshold)
    const MPType clamped = (upper < lower_clamped) ? upper : lower_clamped;
    return static_cast<T>(clamped * value / static_cast<MPType>(scale));
  }
};

// Backward functor for hard_swish.
// The result is built without branching from two 0/1 masks that test
// x + offset against 0 and against threshold. The arithmetic is done in
// MPType and converted back to T.
template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  const MPType zero = static_cast<MPType>(0.0f);
  const MPType one = static_cast<MPType>(1.0f);
  const MPType two = static_cast<MPType>(2.0f);
  float threshold;  // upper clamp, published by GetAttrs() under "threshold"
  float scale;      // divisor, published by GetAttrs() under "scale"
  float offset;     // shift added to x, published under "offset"

  // Returns (attribute name, storage pointer) pairs for the three parameters.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  // dx = 0, when x <= -offset
  //      dout , when x >= threshold - offset
  //      dout * (2 * x / scale + offset / scale), otherwise
  // threshold = scale = 6, offset = 3 by default
  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const MPType grad = static_cast<MPType>(dout);
    const MPType value = static_cast<MPType>(x);
    const MPType shift = static_cast<MPType>(offset);
    const MPType divisor = static_cast<MPType>(scale);
    // 1 when x + offset > 0, else 0
    const MPType above_zero = static_cast<MPType>(value + shift > zero);
    // 1 when x + offset < threshold, else 0
    const MPType below_threshold =
        static_cast<MPType>(value + shift < static_cast<MPType>(threshold));
    const MPType middle_slope =
        above_zero * below_threshold * (two * value + shift) / divisor;

    return static_cast<T>(grad * (middle_slope + one - below_threshold));
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepX;
  }
};

// Forward functor for ceil.
// The input is converted to MPType, rounded up with ceil, and converted back
// to T.
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // ceil(x) = ceil(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType rounded = ceil(value);
    return static_cast<T>(rounded);
  }
};

// Forward functor for floor.
// The input is converted to MPType, rounded down with floor, and converted
// back to T.
template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // floor(x) = floor(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType rounded = floor(value);
    return static_cast<T>(rounded);
  }
};

// Forward functor for round.
// The input is converted to MPType, passed to round, and converted back
// to T.
template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  // round(x) = round(x)
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType value = static_cast<MPType>(arg_x);
    const MPType rounded = round(value);
    return static_cast<T>(rounded);
  }
};

// Shared backward functor for ceil, floor and round.
// The result is always zero; the argument is ignored.
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  __device__ __forceinline__ T operator()(const T x) const {
    const T no_gradient = static_cast<T>(0.0f);
    return no_gradient;
  }

  // No forward input or output is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kNoDeps; }
};

Y
YuanRisheng 已提交
4074 4075 4076 4077 4078 4079 4080 4081 4082 4083 4084 4085 4086 4087 4088 4089 4090 4091 4092 4093 4094 4095 4096 4097 4098 4099 4100 4101 4102 4103 4104 4105 4106 4107 4108 4109 4110 4111 4112 4113 4114 4115 4116 4117 4118 4119 4120 4121 4122 4123 4124 4125 4126
// Forward functor for CELU.
// The arithmetic is done in CT (as given by MPTypeTrait<T>) and the result is
// converted back to T.
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;  // CELU shape parameter, published by GetAttrs() as "alpha"

  // Returns the (attribute name, storage pointer) pair for alpha.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT value = static_cast<CT>(arg_x);
    const CT exp_term =
        static_cast<CT>(alpha) * (exp(value / static_cast<CT>(alpha)) - one);
    // max(0, x)
    const CT positive_part = value > zero ? value : zero;
    // min(0, exp_term)
    const CT negative_part = exp_term > zero ? zero : exp_term;
    return static_cast<T>(positive_part + negative_part);
  }
};

// Backward functor for CELU.
// The arithmetic is done in MPType (as given by MPTypeTrait<T>) and the
// result is converted back to T.
//
// The gradient factor is chosen with a select instead of a sum of 0/1-masked
// terms. In the masked form, exp(x / alpha) is multiplied by a 0 mask when
// x > 0. If x / alpha is large enough for exp to overflow to +inf, that
// product is 0 * inf = NaN and the whole gradient becomes NaN instead of
// dout. A select discards the unused value, so no NaN is introduced.
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  MPType one = static_cast<MPType>(1.0f);
  float alpha;  // CELU shape parameter, published by GetAttrs() as "alpha"

  // Returns the (attribute name, storage pointer) pair for alpha.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  // dx = dout                     if x > 0   (for either sign of alpha)
  // dx = dout * exp(x / alpha)    if x <= 0  (for either sign of alpha)
  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType dout = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType a = static_cast<MPType>(alpha);
    // For x > 0 the factor is exactly 1; exp(x / a) is only used for x <= 0.
    const MPType factor = (x > zero) ? one : exp(x / a);
    return static_cast<T>(dout * factor);
  }

  // Only the forward input X is required by the backward pass.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

4127 4128 4129 4130
#endif

}  // namespace funcs
}  // namespace phi