activation_op.cc 61.1 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Q
qijun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Q
qijun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
16

T
tink2123 已提交
17
#include <memory>
D
dzhwinter 已提交
18
#include <string>
19
#include <type_traits>
T
tink2123 已提交
20
#include <unordered_map>
21
#include <vector>
22

23
#include "paddle/fluid/framework/op_version_registry.h"
24
#include "paddle/fluid/operators/common_infer_shape_functions.h"
25
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
26
#include "paddle/phi/backends/dynload/port.h"
Q
qijun 已提交
27

A
Adam 已提交
28 29
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
30 31 32
namespace paddle {
namespace operators {

33 34
using paddle::framework::Tensor;

35 36
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
37 38
  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
39 40
}

41 42 43 44 45 46 47 48 49 50 51 52 53 54
// Defines class `<OP_NAME>OpMaker`, an OpProtoAndCheckerMaker with one input
// "X", one output "Out", and OP_COMMENT as the op comment. The expansion ends
// with the closing brace of the class, so invoke it with a trailing `;`.
// (Comments are kept outside the #define: a `//` comment on a continued macro
// line would swallow the following lines.)
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                  \
  class OP_NAME##OpMaker                                                   \
      : public ::paddle::framework::OpProtoAndCheckerMaker {               \
   public:                                                                 \
    void Make() override {                                                 \
      AddInput("X",                                                        \
               "Input of " #OP_NAME " operator, an N-D Tensor, with data " \
               "type float32, float64 or float16.");                       \
      AddOutput("Out",                                                     \
                "Output of " #OP_NAME " operator, a Tensor with shape "    \
                "same as input.");                                         \
      AddComment(OP_COMMENT);                                              \
    }                                                                      \
  }
D
dzhwinter 已提交
56

H
hong 已提交
57 58
// Builds "<forward_type>_grad": always wires dOut -> dX and copies the forward
// attributes; additionally feeds X and/or Out according to kDepValue.
template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    constexpr bool kNeedsX = (static_cast<int>(kDepValue) &
                              static_cast<int>(ActBwdOpFwdDeps::kDepX)) != 0;
    constexpr bool kNeedsOut =
        (static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepOut)) != 0;

    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());

    // X is also wired in whenever oneDNN is enabled, either globally through
    // FLAGS_use_mkldnn or through the op's own "use_mkldnn" attribute
    // (which is readable here because the attr map was copied above).
    if (kNeedsX || FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", this->Input("X"));  // x
    }

    if (kNeedsOut) {
      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};
D
dzhwinter 已提交
83

84 85 86 87
// Picks the kernel for an activation-family op from the dtype of variable
// `name`. With PADDLE_WITH_MKLDNN and `oper.CanMKLDNNBeUsed` accepting that
// dtype, the oneDNN kernel with kMKLDNN layout is chosen; otherwise the plain
// kernel with kAnyLayout.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  const auto data_type = oper.IndicateVarDataType(ctx, name);
  // FIXME(liuwei1031) temporarily disable the code to unblock users
  // TODO(liuwei1031) figure out the reason behind
  // https://github.com/PaddlePaddle/Paddle/issues/16096
  // and re-enable this in the future
  // #ifdef PADDLE_WITH_CUDA
  //   auto it1 = oper.Attrs().find("use_cudnn");
  //   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
  //     library = framework::LibraryType::kCUDNN;
  //   }
  // #endif
  // NOTE: if the cuDNN selection above is re-enabled, reintroduce a `library`
  // variable and only take the oneDNN branch while it is still kPlain.
#ifdef PADDLE_WITH_MKLDNN
  if (oper.CanMKLDNNBeUsed(ctx, data_type)) {
    return framework::OpKernelType(data_type,
                                   ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(data_type,
                                 ctx.GetPlace(),
                                 framework::DataLayout::kAnyLayout,
                                 framework::LibraryType::kPlain);
}

Q
qijun 已提交
110 111 112 113
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

114
  void InferShape(framework::InferShapeContext* ctx) const override {
115
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
116
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
117
  }
118

119
 protected:
120 121 122 123
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
J
Jacek Czaja 已提交
124 125

  framework::OpKernelType GetKernelTypeForVar(
126 127
      const std::string& var_name,
      const Tensor& tensor,
128
      const framework::OpKernelType& expected_kernel_type) const override {
J
Jacek Czaja 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141
#ifdef PADDLE_WITH_MKLDNN
    // When activation is first oneDNN op (there was some non oneDNN op
    // previously)
    // then we also need to rotate shape NHWC -> NCWH
    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
        (tensor.layout() != framework::DataLayout::kMKLDNN) &&
        paddle::platform::MKLDNNDeviceContext::tls()
                .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(),
                                     framework::DataLayout::kNHWC);
    }
#endif
142 143
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
J
Jacek Czaja 已提交
144
  }
Q
qijun 已提交
145 146
};

C
chengduo 已提交
147 148 149
// Supplies the input->output name mapping {X -> Out} to
// PassInDtypeAndVarTypeToOutput (which, going by its name, forwards the input's
// dtype and var type to the output — confirm against the base class).
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    // Built once on first use; a reference to it is handed to the caller.
    static std::unordered_map<std::string, std::string> x_to_out = {
        {"X", /*->*/ "Out"}};
    return x_to_out;
  }
};

Q
qijun 已提交
157 158 159 160
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

161
  void InferShape(framework::InferShapeContext* ctx) const override {
162 163 164
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
165
  }
166

167
 protected:
168 169
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
170
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
171
  }
Q
qijun 已提交
172 173
};

D
dzhwinter 已提交
174
// Op comments for the sigmoid/exp/relu family. Each string is handed to
// AddComment through REGISTER_ACTIVATION_OP_MAKER; keep the text byte-for-byte.
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char SiluDoc[] = R"DOC(
Silu Activation Operator

$$out = x * \\frac{1}{1 + e^{-x}}$$
)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Operator. Computes exp of x element-wise with a natural number :math:`e` as the base.

$$out = e^x$$

)DOC";

UNUSED constexpr char Expm1Doc[] = R"DOC(
Expm1 Operator. Computes expm1 of x element-wise with a natural number :math:`e` as the base.

$$out = e^x - 1$$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$$out = \max(x, 0)$$

)DOC";
K
Kexin Zhao 已提交
214

D
dzhwinter 已提交
215
// Op comments for tanh, tanh-shrink, square roots and rounding-to-integer ops.
// Each string is handed to AddComment through REGISTER_ACTIVATION_OP_MAKER.
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$$out=\\sqrt{x}=x^{1/2}$$

**Note**:
  input value must be greater than or equal to zero.

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$$out = \\frac{1}{\\sqrt{x}}$$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Operator. Computes ceil of x element-wise.

$$out = \\lceil x \\rceil$$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator. Computes floor of x element-wise.

$$out = \\lfloor x \\rfloor$$

)DOC";
D
dzhwinter 已提交
261

D
dzhwinter 已提交
262
// Op comments for the trigonometric and hyperbolic activations. Each string is
// handed to AddComment through REGISTER_ACTIVATION_OP_MAKER.
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Operator. Computes cosine of x element-wise.

Input range is `(-inf, inf)` and output range is `[-1,1]`.

$$out = cos(x)$$

)DOC";

UNUSED constexpr char TanDoc[] = R"DOC(
Tangent Operator. Computes tangent of x element-wise.

Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.

$$out = tan(x)$$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$$out = sin(x)$$

)DOC";

UNUSED constexpr char SinhDoc[] = R"DOC(
Sinh Activation Operator.

$$out = sinh(x)$$

)DOC";

UNUSED constexpr char CoshDoc[] = R"DOC(
Cosh Activation Operator.

$$out = cosh(x)$$

)DOC";

UNUSED constexpr char AsinhDoc[] = R"DOC(
Asinh Activation Operator.

$$out = asinh(x)$$

)DOC";

UNUSED constexpr char AcoshDoc[] = R"DOC(
Acosh Activation Operator.

$$out = acosh(x)$$

)DOC";

UNUSED constexpr char AtanhDoc[] = R"DOC(
Atanh Activation Operator.

$$out = atanh(x)$$

)DOC";

D
dzhwinter 已提交
322
// Op comments for round, reciprocal, natural log and log2. Each string is
// handed to AddComment through REGISTER_ACTIVATION_OP_MAKER.
UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.

.. code-block:: text

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$$out = \ln(x)$$

Natural logarithm of x.

)DOC";

UNUSED constexpr char Log2Doc[] = R"DOC(
Log2 Activation Operator.

$$out = \log_2x$$

logarithm of x base to 2.

)DOC";

J
joejiong 已提交
362 363 364 365 366 367 368 369 370
// Op comment for log10, registered through REGISTER_ACTIVATION_OP_MAKER.
// The subscript is braced: `\log_10_x` was malformed TeX (only "1" was
// subscripted, followed by a stray "_x").
UNUSED constexpr char Log10Doc[] = R"DOC(
Log10 Activation Operator.

$$out = \log_{10}x$$

logarithm of x base to 10.

)DOC";

371 372 373 374 375 376 377 378 379
// Op comment for log1p, registered through REGISTER_ACTIVATION_OP_MAKER.
// Previously titled "Log" and described as "logarithm of x"; it also used
// single-$ math, unlike every sibling doc string.
UNUSED constexpr char Log1pDoc[] = R"DOC(
Log1p Activation Operator.

$$out = \ln(x+1)$$

Natural logarithm of (x + 1).

)DOC";

D
dzhwinter 已提交
380
// Op comment for square, registered through REGISTER_ACTIVATION_OP_MAKER.
// Fixes the ungrammatical summary sentence ("The OP square each elements of
// the inputs.").
UNUSED constexpr char SquareDoc[] = R"DOC(
The OP squares each element of the input.

$$out = x^2$$

)DOC";

D
dzhwinter 已提交
387
// Op comment for softsign, registered through REGISTER_ACTIVATION_OP_MAKER.
// The denominator uses the scalar absolute value |x|; the previous `\|x\|`
// renders in TeX as a norm (double bars).
UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + |x|}$$

)DOC";

T
tink2123 已提交
394 395 396 397 398 399
// Op maker for Acos: element-wise arccosine, X -> Out.
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
Arccosine Operator.

$$out = \cos^{-1}(x)$$

)DOC";
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(kDoc);
  }
};
407

T
tink2123 已提交
408 409 410
// Op maker for Asin: element-wise arcsine, X -> Out.
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
Arcsine Operator.

$$out = \sin^{-1}(x)$$

)DOC";
    AddInput("X",
             "Input of asin operator, an N-D Tensor, with data type "
             "float32, float64 or float16.");
    AddOutput("Out", "Output of asin operator");
    AddComment(kDoc);
  }
};
423

T
tink2123 已提交
424 425 426
// Op maker for Atan: element-wise arctangent, X -> Out.
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
Arctangent Operator.

$$out = \tan^{-1}(x)$$

)DOC";
    AddInput("X",
             "Input of atan operator, an N-D Tensor, with data type "
             "float32, float64 or float16.");
    AddOutput("Out", "Output of atan operator");
    AddComment(kDoc);
  }
};
439

D
dzhwinter 已提交
440
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
441
 public:
Y
Yu Yang 已提交
442
  void Make() override {
W
Wilber 已提交
443 444 445 446 447 448 449 450
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
K
Kexin Zhao 已提交
451
    AddComment(R"DOC(
D
dzhwinter 已提交
452
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
453

W
Wilber 已提交
454
$$out = \max(x, \alpha * x)$$
K
Kexin Zhao 已提交
455 456

)DOC");
457 458 459
  }
};

460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482
// Op maker for Softplus: X -> Out with attributes `beta` (default 1.0) and
// `threshold` (default 20.0).
class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
:strong:`Softplus Activation Operator`

..  math::
    out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\
    \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold.

)DOC";
    AddInput("X",
             "Input of Softplus operator, an N-D Tensor, with data type "
             "float32, float64 or float16.");
    AddOutput("Out",
              "Output of Softplus operator, a Tensor with shape same as "
              "input.");
    AddAttr<float>("beta", "The value of beta for Softplus.")
        .SetDefault(1.0f);
    AddAttr<float>("threshold", "The value of threshold for Softplus.")
        .SetDefault(20.0f);
    AddComment(kDoc);
  }
};

D
dzhwinter 已提交
483
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
484
 public:
Y
Yu Yang 已提交
485
  void Make() override {
D
dzhwinter 已提交
486 487 488
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
489
    AddComment(R"DOC(
490 491 492
:strong:`Softshrink Activation Operator`

..  math::
493
    out = \begin{cases}
494 495 496 497
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
498 499

)DOC");
K
kexinzhao 已提交
500 501 502
  }
};

D
dzhwinter 已提交
503
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
504
 public:
Y
Yu Yang 已提交
505
  void Make() override {
D
dzhwinter 已提交
506 507
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
508 509
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
510
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
511
    AddComment(R"DOC(
Y
yuyang18 已提交
512
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
513

Y
yuyang18 已提交
514 515 516 517 518 519
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
520 521

)DOC");
522 523 524
  }
};

525 526
// Op maker for BRelu (bounded ReLU): X -> Out clipped to [t_min, t_max]
// (defaults 0 and 24).
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
BRelu Activation Operator.

$$out = \min(\max(x, t_{min}), t_{max})$$

)DOC";
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``X``.");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(kDoc);
  }
};

// Op maker for SoftRelu: X -> Out with a clipping `threshold`
// (default 40.0).
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
SoftRelu Activation Operator.

$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$

)DOC";
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(kDoc);
  }
};

563 564
// Op maker for ELU: X -> Out with attribute `alpha` (default 1.0).
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$

)DOC";
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(kDoc);
  }
};

Z
zhupengyang 已提交
585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
// Builds "elu_grad", which reads dOut, the forward output Out and the forward
// input X, writes dX, and inherits all forward attributes.
template <typename T>
class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad) const override {
    const auto dout = framework::GradVarName("Out");
    const auto dx = framework::GradVarName("X");

    grad->SetType("elu_grad");
    grad->SetInput(dout, this->OutputGrad("Out"));
    grad->SetInput("Out", this->Output("Out"));
    grad->SetInput("X", this->Input("X"));
    grad->SetOutput(dx, this->InputGrad("X"));
    grad->SetAttrMap(this->Attrs());
  }
};

W
wangzhen38 已提交
601 602 603 604 605 606 607 608 609
// Op maker for Logit: X -> Out with an `eps` attribute described as the input
// clamp bound (default 1e-6). Doc fixes: "defined as follow" -> "as follows",
// and `ln` (rendered by TeX as italic l*n) -> the `\ln` operator.
class LogitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Logit operator");
    AddOutput("Out", "Output of Logit operator");
    AddAttr<float>("eps",
                   "(float, default 1e-6f) the epsilon for input clamp bound")
        .SetDefault(1e-6f);
    AddComment(R"DOC(
Logit Operator.

This function is defined as follows:
$ logit=\ln\left ( {\frac {x} {1-x}} \right ) $

)DOC");
  }
};

// Builds "logit_grad", which reads the forward input X and dOut, writes dX,
// and inherits all forward attributes (including eps).
template <typename T>
class LogitGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    const auto dout = framework::GradVarName("Out");
    const auto dx = framework::GradVarName("X");

    op->SetType("logit_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(dout, this->OutputGrad("Out"));
    op->SetOutput(dx, this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655
// Op maker for CELU: X -> Out with attribute `alpha` (default 1.0).
// The exponent is now braced: `e^(x/\alpha)` only superscripted "(" in TeX.
class CELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of CELU").SetDefault(1.0f);
    AddComment(R"DOC(
CELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1704.07483.

$$out = \max(0, x) + \min(0, \alpha * (e^{x/\alpha} - 1))$$

)DOC");
  }
};

656 657
// Op maker for Relu6: X -> Out clipped to [0, threshold] (default 6.0).
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
Relu6 Activation Operator.

$$out = \min(\max(0, x), threshold)$$

)DOC";
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, with data type "
             "float32, float64.");
    AddOutput("Out",
              "Output of relu6 operator, a Tensor with the same shape as "
              "input.");
    AddAttr<float>("threshold", "The threshold value of Relu6. Default is 6.0. ")
        .SetDefault(6.0f);
    AddComment(kDoc);
  }
};

677 678
// Op maker for Pow: X -> Out = X^factor. The exponent comes from the
// dispensable input FactorTensor when it is provided, otherwise from the
// `factor` attribute (default 1.0).
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    // Adjacent literals are concatenated verbatim, so each fragment carries its
    // own sentence terminator and trailing space. The old text read
    // "...use thisThe shape of FactorTensor MUST BE [1].it has...".
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this. "
             "The shape of FactorTensor MUST BE [1]. "
             "It has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$$out = x^{factor}$$

)DOC");
  }
};

// Op maker for STanh (scaled tanh): X -> Out with input scale `scale_a`
// (default 0.67) and output scale `scale_b` (default 1.7159).
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC";
    AddInput("X",
             "Input of STanh operator. A Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(kDoc);
  }
};

717 718
// Op maker for ThresholdedRelu: X -> Out, zeroing values not above
// `threshold` (default 1.0).
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC";
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(kDoc);
  }
};

738 739
// Op maker for HardSigmoid: X -> Out with a linear approximation defined by
// `slope` (default 0.2) and `offset` (default 0.5).
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
HardSigmoid Activation Operator.

A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$$out = \max(0, \min(1, slope * x + offset))$$

)DOC";
    AddInput("X", "An N-D Tensor with data type float32, float64. ");
    AddOutput("Out", "A Tensor with the same shape as input. ");
    AddAttr<float>("slope",
                   "The slope of the linear approximation of sigmoid. "
                   "Its value MUST BE positive. Default is 0.2. ")
        .SetDefault(0.2f);
    AddAttr<float>("offset",
                   "The offset of the linear approximation of sigmoid. "
                   "Default is 0.5. ")
        .SetDefault(0.5f);
    AddComment(kDoc);
  }
};

A
Abhinav Arora 已提交
763 764
// Op maker for Swish: X -> Out with constant `beta` (default 1.0).
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC";
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator")
        .SetDefault(1.0f);
    AddComment(kDoc);
  }
};

778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803
// Op maker for Mish: X -> Out with a softplus `threshold` (default 20.0).
class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
Mish Activation Operator.

..  math::
    softplus(x) = \begin{cases}
            x, \text{if } x > \text{threshold} \\
            \ln(1 + e^{x}),  \text{otherwise}
          \end{cases}

    out = x * \tanh(softplus(x))

)DOC";
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>("threshold",
                   "Constant threshold of softplus in Mish operator. "
                   "Approximate value of softplus will be used if absolute "
                   "value of input is greater than :attr:`threshold`")
        .SetDefault(20.f);
    AddComment(kDoc);
  }
};

H
huangjun12 已提交
804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819
// Op maker for HardSwish: X -> Out parameterized by `threshold` (default 6.0),
// `scale` (default 6.0) and `offset` (default 3.0).
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC";
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold",
                   "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(kDoc);
  }
};

D
dzhwinter 已提交
830
// One X -> Out op maker per simple activation, each using its *Doc string as
// the op comment.

// Sigmoid / exponential family.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Expm1, Expm1Doc);

// ReLU / tanh / roots / rounding.
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);

// Trigonometric and hyperbolic.
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Acosh, AcoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Asinh, AsinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Atanh, AtanhDoc);

// Rounding, reciprocal and logarithms.
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Log2, Log2Doc);
REGISTER_ACTIVATION_OP_MAKER(Log10, Log10Doc);
REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);

// Polynomial / rational.
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

859
template <ActBwdOpFwdDeps kDepValue>
860 861 862 863 864
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
865 866
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
867
      if (ctx->HasOutput("DX")) {
868 869 870
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
871
      if (ctx->HasOutput("DDOut")) {
872 873 874
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
875
    }
876 877
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
878
      if (ctx->HasOutput("DOut")) {
879 880 881
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
882 883 884 885
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
886 887 888 889
      if (ctx->HasOutput("DOutNew")) {
        ctx->ShareDim("Out", "DOutNew");
        ctx->ShareLoD("Out", "DOutNew");
      }
890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Variant of ActivationOpDoubleGrad whose only inferred output is DDOut.
// DDOut takes X's dims/LoD when kDepValue has kDepX set, and Out's when it has
// kDepOut set. The kernel type is derived from input DDX.
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Operator for third-order activation gradients (e.g. sigmoid_triple_grad).
//
// Shape inference by kDepValue:
//   * kDepX bit set:   DX and DDOut take X's dims and LoD.
//   * kDepOut bit set: D_DOut and D_OutNew take Out's dims and LoD, and
//                      D_DDx takes DDX's dims and LoD.
// Only outputs that are actually present are touched. The kernel type is
// derived from input DDX.
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("D_DOut")) {
        ctx->ShareDim("Out", "D_DOut");
        ctx->ShareLoD("Out", "D_DOut");
      }
      if (ctx->HasOutput("D_OutNew")) {
        ctx->ShareDim("Out", "D_OutNew");
        ctx->ShareLoD("Out", "D_OutNew");
      }
      if (ctx->HasOutput("D_DDx")) {
        ctx->ShareDim("DDX", "D_DDx");
        ctx->ShareLoD("DDX", "D_DDx");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Builds the "sigmoid_grad_grad" op from a sigmoid_grad op.
//   inputs : Out  - sigmoid_grad's Out input
//            DDX  - gradient w.r.t. sigmoid_grad's X@GRAD output
//            DOut - sigmoid_grad's Out@GRAD input
//   outputs: DOutNew - gradient w.r.t. sigmoid_grad's Out input
//            DDOut   - gradient w.r.t. sigmoid_grad's Out@GRAD input
template <typename T>
class SigmoidDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// Builds the "sigmoid_triple_grad" op from a sigmoid_grad_grad op.
//   inputs : Out, DDX, DOut - sigmoid_grad_grad's own inputs
//            D_DDOut        - gradient w.r.t. its DDOut output
//            D_DOut_New     - gradient w.r.t. its DOutNew output
//   outputs: D_OutNew, D_DOut, D_DDx - gradients w.r.t. its Out, DOut, DDX
template <typename T>
class SigmoidTripleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_triple_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->Input("DDX"));
    // input3: dout
    op->SetInput("DOut", this->Input("DOut"));
    // input4: d_ddout
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    // input5: d_dout_new
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
    op->SetAttrMap(this->Attrs());

    // output: d_OutNew, d_dOut, d_ddx
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
  }
};

// Builds the "tanh_grad_grad" op from a tanh_grad op.
//   inputs : Out  - tanh_grad's Out input
//            DDX  - gradient w.r.t. tanh_grad's X@GRAD output
//            DOut - tanh_grad's Out@GRAD input
//   outputs: DOutNew - gradient w.r.t. tanh_grad's Out input
//            DDOut   - gradient w.r.t. tanh_grad's Out@GRAD input
template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

template <typename T>
class TanhTripleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Builds the "tanh_triple_grad" op from a tanh_grad_grad op.
  //   inputs : Out, DDX, DOut - tanh_grad_grad's own inputs, forwarded as-is
  //            D_DDOut        - gradient w.r.t. its DDOut output
  //            D_DOut_New     - gradient w.r.t. its DOutNew output
  //   outputs: D_OutNew, D_DOut, D_DDx - gradients w.r.t. its Out, DOut, DDX
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_triple_grad");

    const char* const forwarded_inputs[] = {"Out", "DDX", "DOut"};
    for (const char* slot : forwarded_inputs) {
      op->SetInput(slot, this->Input(slot));
    }
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));

    op->SetAttrMap(this->Attrs());

    // Each row: {output slot, input whose gradient it carries}.
    const char* const grad_outputs[][2] = {
        {"D_OutNew", "Out"}, {"D_DOut", "DOut"}, {"D_DDx", "DDX"}};
    for (const auto& row : grad_outputs) {
      op->SetOutput(row[0], this->InputGrad(row[1]));
    }
  }
};
1069 1070
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
H
hong 已提交
1071 1072
template <typename T>
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
1073
 public:
H
hong 已提交
1074
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
1075 1076

 protected:
1077
  void Apply(GradOpPtr<T> op) const override {
1078 1079
    op->SetType("relu_grad_grad");
    // input1: Out
H
hong 已提交
1080
    op->SetInput("Out", this->Input("Out"));
Q
qingqing01 已提交
1081
    // input2: ddx
H
hong 已提交
1082 1083
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
1084
    // output: ddy
H
hong 已提交
1085
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
1086 1087 1088
  }
};

1089 1090
// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
H
hong 已提交
1091
template <typename T>
1092
class LeakyReluDoubleGradMaker
H
hong 已提交
1093
    : public ::paddle::framework::SingleGradOpMaker<T> {
1094
 public:
H
hong 已提交
1095
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
1096 1097

 protected:
1098
  void Apply(GradOpPtr<T> op) const override {
1099
    op->SetType("leaky_relu_grad_grad");
1100 1101
    // input1: X
    op->SetInput("X", this->Input("X"));
1102
    // X@GRAD@GRAD: ddx
H
hong 已提交
1103 1104
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
1105
    // Out@GRAD@GRAD: ddy
H
hong 已提交
1106
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
1107 1108 1109
  }
};

D
Double_V 已提交
1110 1111 1112 1113 1114 1115 1116 1117
// elu grad (alpha > 0):     dx    = dout if x > 0 else alpha * dout * exp(x)
// elu gradgrad (alpha > 0): ddout = ddx  if x > 0 else alpha * ddx * exp(x)
//
// Builds the "elu_grad_grad" op from an elu_grad op.
//   inputs : X, DOut (Out@GRAD), DDX (X@GRAD@GRAD)
//   outputs: DX (gradient w.r.t. X), DDOut (Out@GRAD@GRAD)
template <typename T>
class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", this->InputGrad("X"));
    // Out@GRAD@GRAD: ddout
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

// celu grad:     dx    = dout if x > 0 else dout * exp(x / alpha)
// celu gradgrad: ddout = ddx  if x > 0 else ddx * exp(x / alpha)
//                DX    = 0    if x > 0 else dout * ddx * exp(x / alpha) / alpha
template <typename T>
class CELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Builds "celu_grad_grad" from a celu_grad op.
  void Apply(GradOpPtr<T> op) const override {
    const std::string x_grad_var = framework::GradVarName("X");      // X@GRAD
    const std::string out_grad_var = framework::GradVarName("Out");  // Out@GRAD

    op->SetType("celu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(out_grad_var));
    op->SetInput("DDX", this->OutputGrad(x_grad_var));  // X@GRAD@GRAD
    op->SetAttrMap(this->Attrs());

    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(out_grad_var));  // Out@GRAD@GRAD
  }
};

L
lvmengsi 已提交
1156 1157
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
H
hong 已提交
1158 1159
template <typename T>
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
L
lvmengsi 已提交
1160
 public:
H
hong 已提交
1161
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
L
lvmengsi 已提交
1162 1163

 protected:
1164
  void Apply(GradOpPtr<T> op) const override {
L
lvmengsi 已提交
1165
    op->SetType("sqrt_grad_grad");
H
hong 已提交
1166 1167 1168 1169 1170 1171
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
L
lvmengsi 已提交
1172 1173 1174
  }
};

W
whs 已提交
1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193
// rsqrt Grad:     dx    = -0.5 * dout * out^3
// rsqrt GradGrad: ddout = -0.5 * ddx * out^3
//                 DOut  = (3 / out) * dx * ddx   (gradient w.r.t. Out)
template <typename T>
class RsqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Builds "rsqrt_grad_grad" from an rsqrt_grad op. The op reads Out, the
  // rsqrt_grad result DX and DDX. It writes DOut and DDOut.
  void Apply(GradOpPtr<T> op) const override {
    const std::string x_grad_var = framework::GradVarName("X");
    const std::string out_grad_var = framework::GradVarName("Out");

    op->SetType("rsqrt_grad_grad");
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(x_grad_var));       // X@GRAD
    op->SetInput("DDX", this->OutputGrad(x_grad_var));  // X@GRAD@GRAD
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(out_grad_var));  // Out@GRAD@GRAD
  }
};

1194 1195
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
H
hong 已提交
1196 1197
template <typename T>
class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
1198
 public:
H
hong 已提交
1199
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
1200 1201

 protected:
1202
  void Apply(GradOpPtr<T> op) const override {
1203
    op->SetType("square_grad_grad");
H
hong 已提交
1204
    op->SetInput("X", this->Input("X"));
1205
    // Out@GRAD: dy
H
hong 已提交
1206
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
1207
    // X@GRAD@GRAD: ddx
H
hong 已提交
1208
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
1209

H
hong 已提交
1210
    op->SetAttrMap(this->Attrs());
1211 1212

    // X@GRAD: dx
H
hong 已提交
1213
    op->SetOutput("DX", this->InputGrad("X"));
1214
    // Out@GRAD@GRAD: ddy
H
hong 已提交
1215
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
1216 1217 1218
  }
};

1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240
// log grad:     dx    = dout / x
// log gradgrad: ddout = ddx / x
//               DX    = -(dout / x) * (ddx / x)
template <typename T>
class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Builds "log_grad_grad" from a log_grad op.
  void Apply(GradOpPtr<T> op) const override {
    const std::string x_grad_var = framework::GradVarName("X");
    const std::string out_grad_var = framework::GradVarName("Out");

    op->SetType("log_grad_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("DDX", this->OutputGrad(x_grad_var));  // X@GRAD@GRAD
    op->SetInput("DOut", this->Input(out_grad_var));    // Out@GRAD
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DX", this->InputGrad("X"));              // X@GRAD
    op->SetOutput("DDOut", this->InputGrad(out_grad_var));  // Out@GRAD@GRAD
  }
};

1241
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
1242 1243
                           {framework::GradVarName("Out"),  // dout
                            framework::GradVarName("X")});  // dx
1244
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
1245
                           {"DDX", "DDOut"});
1246 1247
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                           {"DDX", "D_DOut"});
1248

W
wangzhen38 已提交
1249 1250
// Forward op for logit. Requires input X and output Out. Out takes X's dims
// and LoD. The kernel is keyed on X's data type at the current place, with
// the plain library and any layout.
class LogitOp : public framework::OperatorWithKernel {
 public:
  // Inherit the (type, inputs, outputs, attrs) constructor. This matches
  // LogitGradOp and the other operators in this file, and replaces the
  // hand-written forwarding constructor.
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(%s) of LogitOp should not be null.", "X"));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"),
                      true,
                      platform::errors::InvalidArgument(
                          "Output(%s) of LogitOp should not be null.", "Out"));

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }
};

// Backward op for logit. Requires Out@GRAD and X as inputs and X@GRAD as
// output. X@GRAD takes X's dims and LoD. The kernel is keyed on X's data type
// at the current place, with the plain library and any layout.
class LogitGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput(framework::GradVarName("Out")),
        true,
        platform::errors::InvalidArgument(
            "Input(%s) of LogitGradOp should not be null.", "DOut"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
                      true,
                      platform::errors::InvalidArgument(
                          "Input(%s) of LogitGradOp should not be null.", "X"));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput(framework::GradVarName("X")),
        true,
        platform::errors::InvalidArgument(
            "Output(%s) of LogitGradOp should not be null.", "DX"));
    auto x_grad_name = framework::GradVarName("X");
    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ x_grad_name);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
  }
};

// Builds the "pow_grad" op. It forwards X, Out@GRAD, the optional
// FactorTensor input and all attributes, and produces X@GRAD.
template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("pow_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetInput("FactorTensor", this->Input("FactorTensor"));
    op->SetAttrMap(this->Attrs());
  }
};

// Forward op for pow. Out takes X's dims and LoD, and the kernel type is
// derived from X. The FactorTensor input is requested with the expected
// kernel type unchanged. Every other input is requested with the expected
// data type but keeps its own place and layout.
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

// Backward op for pow. X@GRAD takes Out@GRAD's dims and LoD, and the kernel
// type is derived from Out@GRAD. FactorTensor is handled as in PowOp: it
// keeps the expected kernel type, and other inputs keep their own place and
// layout.
class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
Q
qijun 已提交
1386 1387 1388 1389
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
1390
namespace plat = paddle::platform;
1391

1392 1393
// Registers activation operator KERNEL_TYPE, using ops::OP_NAME##OpMaker as its
// proto maker, together with its first-order grad op KERNEL_TYPE##_grad.
// The grad op makers are parameterized by grad_functor<float>::FwdDeps().
// ActFwdInplaceInferer is attached only when CanInplaceAct<grad_functor<float>>
// is true; otherwise `void` is passed in its place. The `functor` argument is
// not used by this macro.
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE,                                                          \
      ops::ActivationOp,                                                    \
      ops::OP_NAME##OpMaker,                                                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ops::ActFwdInplaceInferer,                           \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
                    ops::ActivationOpGrad,                                  \
                    ops::ActivationGradOpInplaceInferer);

// Registers float and double CPU kernels for two ops:
//   OP_TYPE        -> ops::ActivationKernel with FUNCTOR<float|double>
//   OP_TYPE##_grad -> ops::ActivationGradKernel with GRAD_FUNCTOR<float|double>
// OP_NAME is accepted but not used by this macro.
#define REGISTER_ACTIVATION_CPU_KERNEL(                                     \
    OP_TYPE, OP_NAME, FUNCTOR, GRAD_FUNCTOR)                                \
  REGISTER_OP_CPU_KERNEL(                                                   \
      OP_TYPE,                                                              \
      ops::ActivationKernel<phi::CPUContext, ops::FUNCTOR<float>>,          \
      ops::ActivationKernel<phi::CPUContext, ops::FUNCTOR<double>>);        \
  REGISTER_OP_CPU_KERNEL(                                                   \
      OP_TYPE##_grad,                                                       \
      ops::ActivationGradKernel<phi::CPUContext, ops::GRAD_FUNCTOR<float>>, \
      ops::ActivationGradKernel<phi::CPUContext, ops::GRAD_FUNCTOR<double>>);

// Apply both registration macros to every entry of FOR_EACH_ACTIVATION_OP.
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);

// Operator definitions (forward op + first-order grad op) for more
// activations. REGISTER_ACTIVATION_OP registers operators only; it does not
// register kernels.
REGISTER_ACTIVATION_OP(cos, Cos, CosFunctor, CosGradFunctor);
REGISTER_ACTIVATION_OP(tan, Tan, TanFunctor, TanGradFunctor);
REGISTER_ACTIVATION_OP(acos, Acos, AcosFunctor, AcosGradFunctor);
REGISTER_ACTIVATION_OP(sin, Sin, SinFunctor, SinGradFunctor);
REGISTER_ACTIVATION_OP(asin, Asin, AsinFunctor, AsinGradFunctor);
REGISTER_ACTIVATION_OP(atan, Atan, AtanFunctor, AtanGradFunctor);
REGISTER_ACTIVATION_OP(sinh, Sinh, SinhFunctor, SinhGradFunctor);
REGISTER_ACTIVATION_OP(cosh, Cosh, CoshFunctor, CoshGradFunctor);
REGISTER_ACTIVATION_OP(asinh, Asinh, AsinhFunctor, AsinhGradFunctor);
REGISTER_ACTIVATION_OP(acosh, Acosh, AcoshFunctor, AcoshGradFunctor);
REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_OP(thresholded_relu,
                       ThresholdedRelu,
                       ThresholdedReluFunctor,
                       ThresholdedReluGradFunctor);
REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_OP(hard_shrink,
                       HardShrink,
                       HardShrinkFunctor,
                       HardShrinkGradFunctor);
REGISTER_ACTIVATION_OP(softshrink,
                       SoftShrink,
                       SoftShrinkFunctor,
                       SoftShrinkGradFunctor);
REGISTER_ACTIVATION_OP(tanh_shrink,
                       TanhShrink,
                       TanhShrinkFunctor,
                       TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
REGISTER_ACTIVATION_OP(softsign,
                       Softsign,
                       SoftsignFunctor,
                       SoftsignGradFunctor);
REGISTER_ACTIVATION_OP(hard_sigmoid,
                       HardSigmoid,
                       HardSigmoidFunctor,
                       HardSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(logsigmoid,
                       LogSigmoid,
                       LogSigmoidFunctor,
                       LogSigmoidGradFunctor);
REGISTER_ACTIVATION_OP(expm1, Expm1, Expm1Functor, Expm1GradFunctor);
REGISTER_ACTIVATION_OP(softplus,
                       Softplus,
                       SoftplusFunctor,
                       SoftplusGradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_OP(reciprocal,
                       Reciprocal,
                       ReciprocalFunctor,
                       ReciprocalGradFunctor);

REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
REGISTER_ACTIVATION_OP(hard_swish,
                       HardSwish,
                       HardSwishFunctor,
                       HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
// round/floor/ceil are registered with ZeroGradFunctor as their grad functor.
REGISTER_ACTIVATION_OP(round, Round, RoundFunctor, ZeroGradFunctor);
REGISTER_ACTIVATION_OP(floor, Floor, FloorFunctor, ZeroGradFunctor);
REGISTER_ACTIVATION_OP(ceil, Ceil, CeilFunctor, ZeroGradFunctor);

/* ==========================    sigmoid register  =============================
 */
// 1. Register Sigmoid Operator
REGISTER_OPERATOR(
    sigmoid,
    ops::ActivationOp,
    ops::SigmoidOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::SigmoidGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);

// 2. Register Sigmoid Grad Operator (its grad maker builds sigmoid_grad_grad)
REGISTER_OPERATOR(sigmoid_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>);

// 3. Register Sigmoid DoubleGrad Operator (grad maker builds the triple grad)
REGISTER_OPERATOR(
    sigmoid_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::SigmoidTripleGradMaker<paddle::framework::OpDesc>,
    ops::SigmoidTripleGradMaker<paddle::imperative::OpBase>);

// 4. Register Sigmoid TripleGrad Operator
REGISTER_OPERATOR(sigmoid_triple_grad,
                  ops::ActivationOpTripleGrad<
                      ops::SigmoidTripleGradFunctor<float>::FwdDeps()>,
                  ops::ActivationTripleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================    tanh register  ============================= */
// Forward op: grad makers for OpDesc and imperative::OpBase; the forward
// inplace inferer is attached only if CanInplaceAct<TanhGradFunctor> holds,
// otherwise that slot resolves to `void`.
REGISTER_OPERATOR(
    tanh,
    ops::ActivationOp,
    ops::TanhOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::TanhGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
// Grad op, carrying the TanhDoubleGradMaker makers.
REGISTER_OPERATOR(tanh_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::TanhDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::TanhDoubleGradMaker<paddle::imperative::OpBase>);
// NOTE(review): takes FwdDeps from TanhGradFunctor, whereas sigmoid_grad_grad
// uses SigmoidGradGradFunctor; confirm this is intended.
REGISTER_OPERATOR(
    tanh_grad_grad,
    ops::ActivationOpDoubleGrad<ops::TanhGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::TanhTripleGradMaker<paddle::framework::OpDesc>,
    ops::TanhTripleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
    tanh_triple_grad,
    ops::ActivationOpTripleGrad<ops::TanhTripleGradFunctor<float>::FwdDeps()>,
    ops::ActivationTripleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================    relu register  ============================= */
1561
REGISTER_OPERATOR(
1562 1563 1564 1565
    relu,
    ops::ActivationOp,
    ops::ReluOpMaker,
    ops::ActivationOpInferVarType,
H
hong 已提交
1566 1567 1568 1569
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
1570
    ops::ActFwdInplaceInferer);
1571 1572
REGISTER_OPERATOR(relu_grad,
                  ops::ActivationOpGrad,
1573
                  ops::ActivationGradOpInplaceInferer,
H
hong 已提交
1574 1575
                  ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
1576 1577
REGISTER_OPERATOR(
    relu_grad_grad,
1578
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
1579
    ops::ActivationDoubleGradOpInplaceInferer);
1580

1581
/* ========================================================================== */
1582

1583
/* ======================== leaky relu register  ============================ */
1584
REGISTER_OPERATOR(
1585 1586 1587
    leaky_relu,
    ops::ActivationOp,
    ops::LeakyReluOpMaker,
1588
    ops::ActivationOpInferVarType,
H
hong 已提交
1589 1590 1591 1592
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
1593
    ops::ActFwdInplaceInferer);
1594 1595
REGISTER_OPERATOR(leaky_relu_grad,
                  ops::ActivationOpGrad,
1596
                  ops::ActivationGradOpInplaceInferer,
H
hong 已提交
1597 1598
                  ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
1599 1600
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
1601
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
1602
    ops::ActivationDoubleGradOpInplaceInferer);
1603 1604 1605

/* ========================================================================== */

D
Double_V 已提交
1606
/* ========================    elu  register     ============================ */
// elu uses its own ELUGradOpMaker instead of the generic ActivationGradOpMaker.
REGISTER_OPERATOR(elu,
                  ops::ActivationOp,
                  ops::ELUOpMaker,
                  ops::ActivationOpInferVarType,
                  ops::ELUGradOpMaker<paddle::framework::OpDesc>,
                  ops::ELUGradOpMaker<paddle::imperative::OpBase>,
                  ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    elu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ========================    logit  register     ============================
 */
// logit uses dedicated LogitOp/LogitGradOp classes; no var-type inferer,
// inplace inferer, or double-grad maker is registered here.
REGISTER_OPERATOR(logit,
                  ops::LogitOp,
                  ops::LogitOpMaker,
                  ops::LogitGradOpMaker<paddle::framework::OpDesc>,
                  ops::LogitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logit_grad, ops::LogitGradOp);

/* ========================================================================== */

/* ========================    celu  register     ============================
 */
// Forward op: the forward inplace inferer is attached unconditionally.
REGISTER_OPERATOR(
    celu,
    ops::ActivationOp,
    ops::CELUOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(celu_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::CELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::CELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    celu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
// Forward op: the forward inplace inferer is attached unconditionally.
REGISTER_OPERATOR(
    sqrt,
    ops::ActivationOp,
    ops::SqrtOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(sqrt_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
// Double-grad FwdDeps come from the dedicated SqrtGradGradFunctor.
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ===========================   rsqrt register  =============================
 */
// Forward op: the forward inplace inferer is attached unconditionally.
REGISTER_OPERATOR(
    rsqrt,
    ops::ActivationOp,
    ops::RsqrtOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(rsqrt_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::RsqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::RsqrtDoubleGradMaker<paddle::imperative::OpBase>);
// Double-grad FwdDeps come from the dedicated RsqrtGradGradFunctor.
REGISTER_OPERATOR(
    rsqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================   square register  ============================ */
// Forward op: the forward inplace inferer is attached unconditionally.
REGISTER_OPERATOR(
    square,
    ops::ActivationOp,
    ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(square_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
// Double-grad FwdDeps come from the dedicated SquareGradGradFunctor.
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================   pow register  ============================ */

// pow uses dedicated PowOp/PowOpGrad classes and PowGradOpMaker. The forward
// inplace inferer is attached only if CanInplaceAct<PowGradFunctor> holds,
// otherwise that slot resolves to `void`. No double-grad maker is registered.
REGISTER_OPERATOR(
    pow,
    ops::PowOp,
    ops::PowOpMaker,
    ops::ActivationOpInferVarType,
    ops::PowGradOpMaker<paddle::framework::OpDesc>,
    ops::PowGradOpMaker<paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(pow_grad,
                  ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInferer);
/* ========================================================================== */

/* ==========================   exp register  ============================ */
// Forward op: the forward inplace inferer is attached only if
// CanInplaceAct<ExpGradFunctor> holds, otherwise that slot resolves to `void`.
// No double-grad maker is registered for exp_grad here.
REGISTER_OPERATOR(
    exp,
    ops::ActivationOp,
    ops::ExpOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::ExpGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ExpGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::ExpGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(exp_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer);

/* ========================================================================== */

/* ==========================  Log register ==================================*/
// Forward op: the forward inplace inferer is attached unconditionally.
REGISTER_OPERATOR(
    log,
    ops::ActivationOp,
    ops::LogOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(log_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::LogDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LogDoubleGradMaker<paddle::imperative::OpBase>);

// Double-grad FwdDeps come from the dedicated LogGradGradFunctor.
REGISTER_OPERATOR(
    log_grad_grad,
    ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

/* ==========================  register checkpoint ===========================*/
// Op-version checkpoints recording behavior/attribute changes per op.
// NOTE(review): several remark strings below contain typos ("bahavior",
// "checkponit", "veriosn"). They are kept verbatim because they are runtime
// data, not comments.
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculate formula before checkponit: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculate formula before checkponit: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

// NOTE(review): the softplus remark is a single raw string literal spanning
// three lines, so the embedded `"` characters, newlines and leading spaces are
// part of the remark text (probably meant as concatenated literals). Kept
// unchanged.
REGISTER_OP_VERSION(softplus).AddCheckpoint(
    R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("beta", "The beta value of the new formula", 1.0f)
        .NewAttr("threshold", "The threshold value of the new formula", 20.0f));

REGISTER_OP_VERSION(mish).AddCheckpoint(
    R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_mkldnn",
        "(bool, default false) Only used in mkldnn kernel",
        false));

/* ========================================================================== */