/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"

#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/phi/backends/dynload/port.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

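// In-place execution is only safe when the backward pass never needs the
// original forward input. CanInplaceAct() captures exactly that: for example,
// relu_grad is computed from "Out" alone (its functor reports kDepOut), so
// relu can reuse the input buffer and is registered with ActFwdInplaceInferer
// near the end of this file, while ops whose grad depends on "X" cannot.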
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
}

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)           \
  class OP_NAME##OpMaker                                            \
      : public ::paddle::framework::OpProtoAndCheckerMaker {        \
   public:                                                          \
    void Make() override {                                          \
      AddInput("X",                                                 \
               "Input of " #OP_NAME                                 \
               " operator, an N-D Tensor, with data type float32, " \
               "float64 or float16.");                              \
      AddOutput("Out",                                              \
                "Output of " #OP_NAME                               \
                " operator, a Tensor with shape same as input.");   \
      AddComment(OP_COMMENT);                                       \
    }                                                               \
  }

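// ActivationGradOpMaker forwards only the tensors the backward functor really
// needs: "X" when kDepValue contains kDepX (or when oneDNN is enabled,
// presumably because the oneDNN grad kernels read the forward input), and
// "Out" when kDepValue contains kDepOut.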
template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", this->Input("X"));  // x
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};

framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto data_type = oper.IndicateVarDataType(ctx, name);
  // FIXME(liuwei1031) temporarily disable the code to unblock users
  // TODO(liuwei1031) figure out the reason behind
  // https://github.com/PaddlePaddle/Paddle/issues/16096
  // and re-enable this in the future
  // #ifdef PADDLE_WITH_CUDA
  //   auto it1 = oper.Attrs().find("use_cudnn");
  //   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
  //     library = framework::LibraryType::kCUDNN;
  //   }
  // #endif

  // NOTE(jiahongyu): Activation ops have attribute use_cudnn, but cudnn kernels
  // are temporarily disabled. Therefore, cudnn kernel also needs to fallback to
  // plain GPU kernel temporarily. When above codes are uncommented, below
  // fallback codes can be deleted safely.
  if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
    oper.SetDnnFallback(true);
  }
  return framework::OpKernelType(data_type, ctx.GetPlace());
}

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$$out = \max(x, 0)$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$$out=\\sqrt{x}=x^{1/2}$$

**Note**:
  input value must be greater than or equal to zero.

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure the input is valid; otherwise numeric errors may occur.

$$out = \\frac{1}{\\sqrt{x}}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$$out = \ln(x)$$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
The OP squares each element of the input.

$$out = x^2$$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$$out = \max(x, \alpha * x)$$

)DOC");
  }
};

class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of Softplus operator, an N-D Tensor, with data type "
             "float32, float64 or float16.");
    AddOutput(
        "Out",
        "Output of Softplus operator, a Tensor with shape same as input.");
    AddAttr<float>("beta", "The value of beta for Softplus.").SetDefault(1.0f);
    AddAttr<float>("threshold", "The value of threshold for Softplus.")
        .SetDefault(20.0f);
    AddComment(R"DOC(
:strong:`Softplus Activation Operator`

..  math::
    out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\
    \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold.

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``X``.");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$$out = \min(\max(x, t_{min}), t_{max})$$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$

)DOC");
  }
};

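// ELU's backward pass reads both the forward input "X" and the forward output
// "Out", so a dedicated grad-op maker (rather than the generic
// ActivationGradOpMaker above) wires up both tensors explicitly.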
template <typename T>
class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elu_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput("X", this->Input("X"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

class CELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of CELU").SetDefault(1.0f);
    AddComment(R"DOC(
CELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1704.07483.

$$out = \max(0, x) + \min(0, \alpha * (e^{x/\alpha} - 1))$$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$$out = \min(\max(0, x), threshold)$$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this"
             "The shape of FactorTensor MUST BE [1]."
             "it has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$$out = x^{factor}$$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>(
        "threshold",
        "Constant threshold of softplus in Mish operator. Approximate value "
        "of softplus will be used if absolute value of input is greater than "
        ":attr:`threshold`")
        .SetDefault(20.f);
    AddComment(R"DOC(
Mish Activation Operator.

..  math::
    softplus(x) = \begin{cases}
            x, \text{if } x > \text{threshold} \\
            \ln(1 + e^{x}),  \text{otherwise}
          \end{cases}

    out = x * \tanh(softplus(x))

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish (https://arxiv.org/pdf/1905.02244.pdf).

$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
      if (ctx->HasOutput("DOutNew")) {
        ctx->ShareDim("Out", "DOutNew");
        ctx->ShareLoD("Out", "DOutNew");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("D_DOut")) {
        ctx->ShareDim("Out", "D_DOut");
        ctx->ShareLoD("Out", "D_DOut");
      }
      if (ctx->HasOutput("D_OutNew")) {
        ctx->ShareDim("Out", "D_OutNew");
        ctx->ShareLoD("Out", "D_OutNew");
      }
      if (ctx->HasOutput("D_DDx")) {
        ctx->ShareDim("DDX", "D_DDx");
        ctx->ShareLoD("DDX", "D_DDx");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

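// Naming convention used by the double/triple grad makers below: "DDX" and
// "DDOut" are the second-order gradients w.r.t. X and Out (X@GRAD@GRAD,
// Out@GRAD@GRAD), "DOut" is the first-order gradient of Out, and the "D_*"
// tensors are gradients of those quantities produced by the next level of
// differentiation.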
template <typename T>
class SigmoidDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

template <typename T>
class SigmoidTripleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_triple_grad");
    // Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    // D_OutNew, D_DOut, D_DDx               // output
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->Input("DDX"));
    // input3: dout
    op->SetInput("DOut", this->Input("DOut"));
    // input4: d_ddout
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    // input5: d_dout_new
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
    op->SetAttrMap(this->Attrs());

    // output: d_dOut, d_OutNew, d_ddx
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
  }
};

// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
template <typename T>
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
template <typename T>
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("leaky_relu_grad_grad");
    // input1: X
    op->SetInput("X", this->Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// elu grad: dx=dy if y>0 else alpha*dy*x.exp()
// elu gradgrad: ddx=ddy if y>0 else alpha*ddy*x.exp()
template <typename T>
class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // Out@GRAD@GRAD: ddy
    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// celu grad: dx=dy if y>0 else dy*(x/alpha).exp()
// celu gradgrad: ddx=ddy if y>0 else ddy*(x/alpha).exp()/alpha
template <typename T>
class CELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("celu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // Out@GRAD@GRAD: ddy
    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
template <typename T>
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * ddx
template <typename T>
class RsqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("rsqrt_grad_grad");
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
template <typename T>
class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("square_grad_grad");
    op->SetInput("X", this->Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(this->Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", this->InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// log Grad: dx = dout / x
// log Grad Grad: ddout = ddx / x; dx = -(dout / x) * (ddx / x)
template <typename T>
class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("log_grad_grad");
    op->SetInput("X", this->Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // X@GRAD: dx
    op->SetOutput("DX", this->InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
                           {framework::GradVarName("Out"),  // dout
                            framework::GradVarName("X")});  // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                           {"DDX", "D_DOut"});

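// Pow keeps its own op, grad op and grad maker (instead of the generic
// activation ones) because of the optional "FactorTensor" input: it has to be
// forwarded to pow_grad and, judging from GetKernelTypeForVar below, is kept
// out of the usual data-type/layout transform.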
template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("pow_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetInput("FactorTensor", this->Input("FactorTensor"));
    op->SetAttrMap(this->Attrs());
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE,                                                          \
      ops::ActivationOp,                                                    \
      ops::OP_NAME##OpMaker,                                                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ops::ActFwdInplaceInferer,                           \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
                    ops::ActivationOpGrad,                                  \
                    ops::ActivationGradOpInplaceInferer);

#define REGISTER_ACTIVATION_CPU_KERNEL(                                     \
    act_type, op_name, functor, grad_functor)                               \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type,                                                             \
      ops::ActivationKernel<phi::CPUContext, ops::functor<float>>,          \
      ops::ActivationKernel<phi::CPUContext, ops::functor<double>>);        \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type##_grad,                                                      \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<float>>, \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<double>>);
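// FOR_EACH_ACTIVATION_OP (expected to be defined in activation_op.h) applies
// the two macros above to every activation in its list, registering the
// forward and "_grad" operators plus float/double CPU kernels in one shot.
// Activations that need extra attributes, inputs or higher-order grads are
// registered individually below.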
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);

REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_OP(thresholded_relu,
                       ThresholdedRelu,
                       ThresholdedReluFunctor,
                       ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_OP(softshrink,
                       SoftShrink,
                       SoftShrinkFunctor,
                       SoftShrinkGradFunctor);
REGISTER_ACTIVATION_OP(tanh_shrink,
                       TanhShrink,
                       TanhShrinkFunctor,
                       TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(softsign,
                       Softsign,
                       SoftsignFunctor,
                       SoftsignGradFunctor);
REGISTER_ACTIVATION_OP(softplus,
                       Softplus,
                       SoftplusFunctor,
                       SoftplusGradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_OP(hard_swish,
                       HardSwish,
                       HardSwishFunctor,
                       HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
/* ==========================    sigmoid register  =============================
 */
// 1. Register Sigmoid Operator
REGISTER_OPERATOR(
    sigmoid,
    ops::ActivationOp,
    ops::SigmoidOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::SigmoidGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);

// 2. Register Sigmoid Grad Operator
REGISTER_OPERATOR(sigmoid_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>);

// 3. Register Sigmoid DoubleGrad Operator
REGISTER_OPERATOR(
    sigmoid_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::SigmoidTripleGradMaker<paddle::framework::OpDesc>,
    ops::SigmoidTripleGradMaker<paddle::imperative::OpBase>);

// 4. Register Sigmoid TripleGrad Operator
REGISTER_OPERATOR(sigmoid_triple_grad,
                  ops::ActivationOpTripleGrad<
                      ops::SigmoidTripleGradFunctor<float>::FwdDeps()>,
                  ops::ActivationTripleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================    relu register  ============================= */
REGISTER_OPERATOR(
    relu,
    ops::ActivationOp,
    ops::ReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(relu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ======================== leaky relu register  ============================ */
REGISTER_OPERATOR(
    leaky_relu,
    ops::ActivationOp,
    ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(leaky_relu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ========================    elu  register     ============================ */
REGISTER_OPERATOR(elu,
                  ops::ActivationOp,
                  ops::ELUOpMaker,
                  ops::ActivationOpInferVarType,
                  ops::ELUGradOpMaker<paddle::framework::OpDesc>,
                  ops::ELUGradOpMaker<paddle::imperative::OpBase>,
                  ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    elu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ========================    celu  register     ============================
 */
REGISTER_OPERATOR(
    celu,
    ops::ActivationOp,
    ops::CELUOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(celu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::CELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::CELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    celu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt,
    ops::ActivationOp,
    ops::SqrtOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);
/* ========================================================================== */

/* ===========================   rsqrt register  =============================
 */
REGISTER_OPERATOR(
    rsqrt,
    ops::ActivationOp,
    ops::RsqrtOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(rsqrt_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::RsqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::RsqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    rsqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square,
    ops::ActivationOp,
    ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================   pow register  ============================ */

REGISTER_OPERATOR(
    pow,
    ops::PowOp,
    ops::PowOpMaker,
    ops::ActivationOpInferVarType,
    ops::PowGradOpMaker<paddle::framework::OpDesc>,
    ops::PowGradOpMaker<paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(pow_grad,
                  ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInferer);
/* ========================================================================== */

/* ==========================  Log register ==================================*/
REGISTER_OPERATOR(
    log,
    ops::ActivationOp,
    ops::LogOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(log_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::LogDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LogDoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
    log_grad_grad,
    ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ==========================  register checkpoint ===========================*/
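// These checkpoints record attribute additions and behavior changes of the ops
// so that programs/models saved with an older operator definition can be
// recognized and handled compatibly when loaded (see op_version_registry.h for
// the mechanism).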
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculate formula before checkponit: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculate formula before checkponit: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

REGISTER_OP_VERSION(softplus).AddCheckpoint(
    R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("beta", "The beta value of the new formula", 1.0f)
        .NewAttr("threshold", "The threshold value of the new formula", 20.0f));

REGISTER_OP_VERSION(mish).AddCheckpoint(
    R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_mkldnn",
        "(bool, default false) Only used in mkldnn kernel",
        false));

/* ========================================================================== */