/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"

#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/phi/backends/dynload/port.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {
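// CanInplaceAct<GradFunctor>() below is the compile-time switch used by the
// registration macros at the end of this file: an activation op may reuse its
// input buffer for the output only when the backward functor does not need the
// forward input X, i.e. when it depends solely on Out (kDepOut) or on nothing
// (kNoDeps).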

template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
}

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)           \
  class OP_NAME##OpMaker                                            \
      : public ::paddle::framework::OpProtoAndCheckerMaker {        \
   public:                                                          \
    void Make() override {                                          \
      AddInput("X",                                                 \
               "Input of " #OP_NAME                                 \
               " operator, an N-D Tensor, with data type float32, " \
               "float64 or float16.");                              \
      AddOutput("Out",                                              \
                "Output of " #OP_NAME                               \
                " operator, a Tensor with shape same as input.");   \
      AddComment(OP_COMMENT);                                       \
    }                                                               \
  }
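// For illustration only (a rough sketch, not part of the build): an invocation
// such as REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc) expands to roughly
//
//   class SigmoidOpMaker : public ::paddle::framework::OpProtoAndCheckerMaker {
//    public:
//     void Make() override {
//       AddInput("X", "Input of Sigmoid operator, an N-D Tensor, ...");
//       AddOutput("Out", "Output of Sigmoid operator, a Tensor ...");
//       AddComment(SigmoidDoc);
//     }
//   };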

template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", this->Input("X"));  // x
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};
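// Example of how the dependency flag plays out: for a kDepOut grad functor
// (e.g. the tanh or relu grad functors) only "Out" and GradVarName("Out") are
// wired into the grad op, while a kDepX functor (e.g. the cos grad functor)
// also receives the forward input "X"; when MKL-DNN is enabled, "X" is always
// forwarded as well.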

framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto data_type = oper.IndicateVarDataType(ctx, name);
  // FIXME(liuwei1031) temporarily disable the code to unblock users
  // TODO(liuwei1031) figure out the reason behind
  // https://github.com/PaddlePaddle/Paddle/issues/16096
  // and re-enable this in the future
  // #ifdef PADDLE_WITH_CUDA
  //   auto it1 = oper.Attrs().find("use_cudnn");
  //   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
  //     library = framework::LibraryType::kCUDNN;
  //   }
  // #endif
  return framework::OpKernelType(data_type, ctx.GetPlace());
}
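// GetKernelType above is shared by the forward and backward activation ops
// defined below: ActivationOp picks the kernel data type from input "X",
// while ActivationOpGrad picks it from GradVarName("Out").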

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char SiluDoc[] = R"DOC(
Silu Activation Operator

$$out = x * \\frac{1}{1 + e^{-x}}$$
)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char Expm1Doc[] = R"DOC(
Expm1 Operator. Computes expm1 of x element-wise with a natural number :math:`e` as the base.

$$out = e^x - 1$$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$$out = \max(x, 0)$$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$$out=\\sqrt{x}=x^{1/2}$$

**Note**:
  input value must be greater than or equal to zero.

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure the input is valid to avoid numeric errors.

$$out = \\frac{1}{\\sqrt{x}}$$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Operator. Computes ceil of x element-wise.

..  math::
    out = \left \lceil x \right \rceil

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator. Computes floor of x element-wise.

$$out = \\lfloor x \\rfloor$$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Operator. Computes cosine of x element-wise.

Input range is `(-inf, inf)` and output range is `[-1,1]`.

..  math::
    out = cos(x)

)DOC";

UNUSED constexpr char TanDoc[] = R"DOC(
Tangent Operator. Computes tangent of x element-wise.

Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.

$$out = tan(x)$$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$$out = sin(x)$$

)DOC";

UNUSED constexpr char SinhDoc[] = R"DOC(
Sinh Activation Operator.

$$out = sinh(x)$$

)DOC";

UNUSED constexpr char CoshDoc[] = R"DOC(
Cosh Activation Operator.

Input range `(-inf, inf)`, output range `[1, inf)`.

..  math::
    out = \frac{exp(x)+exp(-x)}{2}

)DOC";

UNUSED constexpr char AsinhDoc[] = R"DOC(
Asinh Activation Operator.

$$out = asinh(x)$$

)DOC";

UNUSED constexpr char AcoshDoc[] = R"DOC(
Acosh Activation Operator.

$$out = acosh(x)$$

)DOC";

UNUSED constexpr char AtanhDoc[] = R"DOC(
Atanh Activation Operator.

$$out = atanh(x)$$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.

.. code-block:: text

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$$out = \ln(x)$$

Natural logarithm of x.

)DOC";

UNUSED constexpr char Log2Doc[] = R"DOC(
Log2 Activation Operator.

$$out = \log_2 x$$

Base-2 logarithm of x.

)DOC";

UNUSED constexpr char Log10Doc[] = R"DOC(
Log10 Activation Operator.

$$out = \log_{10} x$$

Base-10 logarithm of x.

)DOC";

UNUSED constexpr char Log1pDoc[] = R"DOC(
Log1p Activation Operator.

$out = \ln(x+1)$

Natural logarithm of x plus one.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
The OP squares each element of the input.

$$out = x^2$$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Tensor, same shape and dtype as input");
    AddComment(R"DOC(
Arccosine Operator.

..  math::
    out = \cos^{-1}(x)

)DOC");
  }
};

class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of asin operator, an N-D Tensor, with data type float32, "
             "float64 or float16.");
    AddOutput("Out", "Tensor, same shape and dtype as input.");
    AddComment(R"DOC(
Arcsine Operator.

..  math::
    out = \sin^{-1}(x)

)DOC");
  }
};

class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of atan operator, an N-D Tensor, with data type float32, "
             "float64 or float16.");
    AddOutput("Out", "Tensor, same shape and dtype as input x");
    AddComment(R"DOC(
Arctangent Operator.

..  math::
    out = \tan^{-1}(x)

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$$out = \max(x, \alpha * x)$$

)DOC");
  }
};

class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of Softplus operator, an N-D Tensor, with data type "
             "float32, float64 or float16.");
    AddOutput(
        "Out",
        "Output of Softplus operator, a Tensor with shape same as input.");
    AddAttr<float>("beta", "The value of beta for Softplus.").SetDefault(1.0f);
    AddAttr<float>("threshold", "The value of threshold for Softplus.")
        .SetDefault(20.0f);
    AddComment(R"DOC(
:strong:`Softplus Activation Operator`

..  math::
    out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\
    \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold.

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``X``.");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$$out = \min(\max(x, t_{min}), t_{max})$$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$

)DOC");
  }
};

template <typename T>
class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elu_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput("X", this->Input("X"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

class LogitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Logit operator");
    AddOutput("Out", "Output of Logit operator");
    AddAttr<float>("eps",
                   "(float, default 1e-6f) the epsilon for input clamp bound")
        .SetDefault(1e-6f);
    AddComment(R"DOC(
Logit Operator.

This function is defined as follows:

$$logit = \ln\left( \frac{x}{1-x} \right)$$

)DOC");
  }
};

template <typename T>
class LogitGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("logit_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

class CELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of CELU").SetDefault(1.0f);
    AddComment(R"DOC(
CELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1704.07483.

$$out = \max(0, x) + \min(0, \alpha * (e^{x/\alpha} - 1))$$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$$out = \min(\max(0, x), threshold)$$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this. "
             "The shape of FactorTensor MUST BE [1]. "
             "It has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$$out = x^{factor}$$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "An N-D Tensor with data type float32, float64. ");
    AddOutput("Out", "A Tensor with the same shape as input. ");
    AddAttr<float>("slope",
                   "The slope of the linear approximation of sigmoid. Its "
                   "value MUST BE positive. Default is 0.2. ")
        .SetDefault(0.2f);
    AddAttr<float>(
        "offset",
        "The offset of the linear approximation of sigmoid. Default is 0.5. ")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

A 3-part piecewise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$$out = \max(0, \min(1, slope * x + offset))$$

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>(
        "threshold",
        "Constant threshold of softplus in Mish operator. Approximate value "
        "of softplus will be used if absolute value of input is greater than "
        ":attr:`threshold`")
        .SetDefault(20.f);
    AddComment(R"DOC(
Mish Activation Operator.

..  math::
    softplus(x) = \begin{cases}
            x, \text{if } x > \text{threshold} \\
            \ln(1 + e^{x}),  \text{otherwise}
          \end{cases}

    out = x * \tanh(softplus(x))

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish (https://arxiv.org/pdf/1905.02244.pdf).

$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Expm1, Expm1Doc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Acosh, AcoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Asinh, AsinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Atanh, AtanhDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Log2, Log2Doc);
REGISTER_ACTIVATION_OP_MAKER(Log10, Log10Doc);
REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
      if (ctx->HasOutput("DOutNew")) {
        ctx->ShareDim("Out", "DOutNew");
        ctx->ShareLoD("Out", "DOutNew");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("D_DOut")) {
        ctx->ShareDim("Out", "D_DOut");
        ctx->ShareLoD("Out", "D_DOut");
      }
      if (ctx->HasOutput("D_OutNew")) {
        ctx->ShareDim("Out", "D_OutNew");
        ctx->ShareLoD("Out", "D_OutNew");
      }
      if (ctx->HasOutput("D_DDx")) {
        ctx->ShareDim("DDX", "D_DDx");
        ctx->ShareLoD("DDX", "D_DDx");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <typename T>
class SigmoidDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};
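// Naming convention used by the double/triple grad makers in this file
// (a descriptive note, not new behavior): "DDX" is the incoming second-order
// grad w.r.t. X, "DDOut" is the second-order grad produced w.r.t. Out, and
// names such as "DOutNew", "D_DOut" and "D_DDx" are the extra gradients that
// the higher-order kernels emit for their own inputs.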

template <typename T>
class SigmoidTripleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_triple_grad");
    // Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    // D_OutNew, D_DOut, D_DDx               // output
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->Input("DDX"));
    // input3: dout
    op->SetInput("DOut", this->Input("DOut"));
    // input4: d_ddout
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    // input5: d_dout_new
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
    op->SetAttrMap(this->Attrs());

    // output: d_dOut, d_OutNew, d_ddx
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
  }
};

template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

template <typename T>
class TanhTripleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_triple_grad");
    // Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    // D_OutNew, D_DOut, D_DDx               // output
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->Input("DDX"));
    // input3: dout
    op->SetInput("DOut", this->Input("DOut"));
    // input4: d_ddout
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    // input5: d_dout_new
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
    op->SetAttrMap(this->Attrs());

    // output: d_dOut, d_OutNew, d_ddx
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
  }
};
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
template <typename T>
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
template <typename T>
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("leaky_relu_grad_grad");
    // input1: X
    op->SetInput("X", this->Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// elu grad: dx=dy if y>0 else alpha*dy*x.exp()
// elu gradgrad: ddx=ddy if y>0 else alpha*ddy*x.exp()
template <typename T>
class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // Out@GRAD@GRAD: ddy
    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// celu grad: dx=dy if y>0 else dy*(x/alpha).exp()
// celu gradgrad: ddx=ddy if y>0 else ddy*(x/alpha).exp()/alpha
template <typename T>
class CELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("celu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // Out@GRAD@GRAD: ddy
    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
template <typename T>
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * ddx
template <typename T>
class RsqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("rsqrt_grad_grad");
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
template <typename T>
class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("square_grad_grad");
    op->SetInput("X", this->Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(this->Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", this->InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// log Grad: dx = dout / x
// log Grad Grad: ddout = ddx / x; dx = -(dout / x) * (ddx / x)
template <typename T>
class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("log_grad_grad");
    op->SetInput("X", this->Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // X@GRAD: dx
    op->SetOutput("DX", this->InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
                           {framework::GradVarName("Out"),  // dout
                            framework::GradVarName("X")});  // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                           {"DDX", "D_DOut"});

class LogitOp : public framework::OperatorWithKernel {
 public:
  LogitOp(const std::string& type,
          const framework::VariableNameMap& inputs,
          const framework::VariableNameMap& outputs,
          const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(%s) of LogitOp should not be null.", "X"));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"),
                      true,
                      platform::errors::InvalidArgument(
                          "Output(%s) of LogitOp should not be null.", "Out"));

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }
};

class LogitGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Out")),
        true,
        platform::errors::InvalidArgument(
            "Input(%s) of LogitGradOp should not be null.", "DOut"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(%s) of LogitGradOp should not be null.", "X"));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput(framework::GradVarName("X")),
        true,
        platform::errors::InvalidArgument(
            "Output(%s) of LogitGradOp should not be null.", "DX"));
    auto x_grad_name = framework::GradVarName("X");
    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ x_grad_name);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
    phi::DataLayout layout = phi::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }
};

template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("pow_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetInput("FactorTensor", this->Input("FactorTensor"));
    op->SetAttrMap(this->Attrs());
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};
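// Note: for "FactorTensor" the PowOp/PowOpGrad GetKernelTypeForVar overrides
// above return the expected kernel type unchanged, which effectively skips the
// usual place/layout data transformation for that optional input.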
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE,                                                          \
      ops::ActivationOp,                                                    \
      ops::OP_NAME##OpMaker,                                                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ops::ActFwdInplaceInferer,                           \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
                    ops::ActivationOpGrad,                                  \
                    ops::ActivationGradOpInplaceInferer);

#define REGISTER_ACTIVATION_CPU_KERNEL(                                     \
    act_type, op_name, functor, grad_functor)                               \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type,                                                             \
      ops::ActivationKernel<phi::CPUContext, ops::functor<float>>,          \
      ops::ActivationKernel<phi::CPUContext, ops::functor<double>>);        \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type##_grad,                                                      \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<float>>, \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<double>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
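// The explicit REGISTER_ACTIVATION_OP lines below register a forward op and
// its "_grad" op in one shot; for example the "cos" line creates the "cos" and
// "cos_grad" operators, using CosGradFunctor::FwdDeps() to decide which
// forward tensors the grad op needs and whether the forward op may run
// in place.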

1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401
REGISTER_ACTIVATION_OP(cos, Cos, CosFunctor, CosGradFunctor)
REGISTER_ACTIVATION_OP(tan, Tan, TanFunctor, TanGradFunctor);
REGISTER_ACTIVATION_OP(acos, Acos, AcosFunctor, AcosGradFunctor);
REGISTER_ACTIVATION_OP(sin, Sin, SinFunctor, SinGradFunctor);
REGISTER_ACTIVATION_OP(asin, Asin, AsinFunctor, AsinGradFunctor);
REGISTER_ACTIVATION_OP(atan, Atan, AtanFunctor, AtanGradFunctor);
REGISTER_ACTIVATION_OP(sinh, Sinh, SinhFunctor, SinhGradFunctor);
REGISTER_ACTIVATION_OP(cosh, Cosh, CoshFunctor, CoshGradFunctor);
REGISTER_ACTIVATION_OP(asinh, Asinh, AsinhFunctor, AsinhGradFunctor);
REGISTER_ACTIVATION_OP(acosh, Acosh, AcoshFunctor, AcoshGradFunctor);
REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
1402
REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
1403 1404 1405 1406
REGISTER_ACTIVATION_OP(thresholded_relu,
                       ThresholdedRelu,
                       ThresholdedReluFunctor,
                       ThresholdedReluGradFunctor);
1407
REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
1408 1409 1410
REGISTER_ACTIVATION_OP(hard_shrink,
                       HardShrink,
                       HardShrinkFunctor,
Y
YuanRisheng 已提交
1411
                       HardShrinkGradFunctor);
1412 1413 1414
REGISTER_ACTIVATION_OP(softshrink,
                       SoftShrink,
                       SoftShrinkFunctor,
Y
YuanRisheng 已提交
1415
                       SoftShrinkGradFunctor);
1416 1417 1418
REGISTER_ACTIVATION_OP(tanh_shrink,
                       TanhShrink,
                       TanhShrinkFunctor,
Y
YuanRisheng 已提交
1419 1420
                       TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
1421 1422 1423 1424
REGISTER_ACTIVATION_OP(softsign,
                       Softsign,
                       SoftsignFunctor,
                       SoftsignGradFunctor);
1425 1426 1427
REGISTER_ACTIVATION_OP(hard_sigmoid,
                       HardSigmoid,
                       HardSigmoidFunctor,
Y
YuanRisheng 已提交
1428
                       HardSigmoidGradFunctor);
1429 1430 1431
REGISTER_ACTIVATION_OP(logsigmoid,
                       LogSigmoid,
                       LogSigmoidFunctor,
Y
YuanRisheng 已提交
1432
                       LogSigmoidGradFunctor);
1433
REGISTER_ACTIVATION_OP(expm1, Expm1, Expm1Functor, Expm1GradFunctor);
1434 1435 1436
REGISTER_ACTIVATION_OP(softplus,
                       Softplus,
                       SoftplusFunctor,
1437 1438 1439
                       SoftplusGradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
1440 1441 1442
REGISTER_ACTIVATION_OP(reciprocal,
                       Reciprocal,
                       ReciprocalFunctor,
1443 1444
                       ReciprocalGradFunctor);

1445 1446 1447
REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
1448 1449 1450
REGISTER_ACTIVATION_OP(hard_swish,
                       HardSwish,
                       HardSwishFunctor,
Y
YuanRisheng 已提交
1451 1452 1453 1454 1455
                       HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
REGISTER_ACTIVATION_OP(round, Round, RoundFunctor, ZeroGradFunctor);
REGISTER_ACTIVATION_OP(floor, Floor, FloorFunctor, ZeroGradFunctor);
REGISTER_ACTIVATION_OP(ceil, Ceil, CeilFunctor, ZeroGradFunctor);

/* ==========================    sigmoid register  =============================
 */
// 1. Register Sigmoid Operator
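// Note: the std::conditional below only enables forward in-place inference
// (ActFwdInplaceInferer) when CanInplaceAct<SigmoidGradFunctor>() holds,
// i.e. when the grad functor depends only on Out (or on nothing), so that
// overwriting X in the forward pass cannot break the backward pass. For
// sigmoid this holds because dx = dout * out * (1 - out) uses only the
// forward output.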
REGISTER_OPERATOR(
    sigmoid,
    ops::ActivationOp,
    ops::SigmoidOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::SigmoidGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);

// 2. Register Sigmoid Grad Operator
REGISTER_OPERATOR(sigmoid_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>);

// 3. Register Sigmoid DoubleGrad Operator
REGISTER_OPERATOR(
    sigmoid_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::SigmoidTripleGradMaker<paddle::framework::OpDesc>,
    ops::SigmoidTripleGradMaker<paddle::imperative::OpBase>);

// 4. Register Sigmoid TripleGrad Operator
REGISTER_OPERATOR(sigmoid_triple_grad,
                  ops::ActivationOpTripleGrad<
                      ops::SigmoidTripleGradFunctor<float>::FwdDeps()>,
                  ops::ActivationTripleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================    tanh register  ============================= */
REGISTER_OPERATOR(
    tanh,
    ops::ActivationOp,
    ops::TanhOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::TanhGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(tanh_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::TanhDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::TanhDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    tanh_grad_grad,
    ops::ActivationOpDoubleGrad<ops::TanhGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::TanhTripleGradMaker<paddle::framework::OpDesc>,
    ops::TanhTripleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
    tanh_triple_grad,
    ops::ActivationOpTripleGrad<ops::TanhTripleGradFunctor<float>::FwdDeps()>,
    ops::ActivationTripleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================    relu register  ============================= */
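// Note: relu registers ActFwdInplaceInferer unconditionally (no
// std::conditional guard as for sigmoid above); this relies on the relu
// gradient, dx = dout * (out > 0), needing only the forward output, so X may
// safely be overwritten in place.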
REGISTER_OPERATOR(
    relu,
    ops::ActivationOp,
    ops::ReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(relu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ======================== leaky relu register  ============================ */
REGISTER_OPERATOR(
    leaky_relu,
    ops::ActivationOp,
    ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(leaky_relu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ========================    elu  register     ============================ */
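// Note: unlike most activations above, elu wires its backward op through a
// dedicated ELUGradOpMaker rather than the generic ActivationGradOpMaker,
// presumably because its gradient needs inputs that the generic maker does
// not forward.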
REGISTER_OPERATOR(elu,
                  ops::ActivationOp,
                  ops::ELUOpMaker,
                  ops::ActivationOpInferVarType,
                  ops::ELUGradOpMaker<paddle::framework::OpDesc>,
                  ops::ELUGradOpMaker<paddle::imperative::OpBase>,
                  ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    elu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ========================    logit  register     ============================
 */
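// Note: logit does not reuse the generic ActivationOp/ActivationOpGrad pair;
// it registers its own LogitOp / LogitGradOp classes together with a
// dedicated LogitGradOpMaker.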
REGISTER_OPERATOR(logit,
                  ops::LogitOp,
                  ops::LogitOpMaker,
                  ops::LogitGradOpMaker<paddle::framework::OpDesc>,
                  ops::LogitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logit_grad, ops::LogitGradOp);

/* ========================================================================== */

/* ========================    celu  register     ============================
 */
REGISTER_OPERATOR(
    celu,
    ops::ActivationOp,
    ops::CELUOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(celu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::CELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::CELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    celu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt,
    ops::ActivationOp,
    ops::SqrtOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ===========================   rsqrt register  =============================
 */
REGISTER_OPERATOR(
    rsqrt,
    ops::ActivationOp,
    ops::RsqrtOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(rsqrt_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::RsqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::RsqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    rsqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square,
    ops::ActivationOp,
    ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================   pow register  ============================ */
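// Note: pow also uses dedicated PowOp / PowOpMaker / PowGradOpMaker classes
// instead of the generic ActivationOp machinery, while forward in-place is
// still guarded by the same CanInplaceAct<PowGradFunctor> check used for
// sigmoid above.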

REGISTER_OPERATOR(
    pow,
    ops::PowOp,
    ops::PowOpMaker,
    ops::ActivationOpInferVarType,
    ops::PowGradOpMaker<paddle::framework::OpDesc>,
    ops::PowGradOpMaker<paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(pow_grad,
                  ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInferer);
/* ========================================================================== */

/* ==========================  Log register ==================================*/
REGISTER_OPERATOR(
    log,
    ops::ActivationOp,
    ops::LogOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(log_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::LogDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LogDoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
    log_grad_grad,
    ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ==========================  register checkpoint ===========================*/
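// The version checkpoints below record how these operators' attributes and
// formulas changed over time, so that programs serialized with an older
// operator definition can be upgraded consistently when loaded by a newer
// framework build.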
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculate formula before checkponit: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculate formula before checkponit: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

REGISTER_OP_VERSION(softplus).AddCheckpoint(
    R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
1759 1760
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("beta", "The beta value of the new formula", 1.0f)
        .NewAttr("threshold", "The threshold value of the new formula", 20.0f));

REGISTER_OP_VERSION(mish).AddCheckpoint(
    R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_mkldnn",
        "(bool, default false) Only used in mkldnn kernel",
        false));

/* ========================================================================== */