activation_op.cc 52.9 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Q
qijun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Q
qijun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
16

T
tink2123 已提交
17
#include <memory>
D
dzhwinter 已提交
18
#include <string>
19
#include <type_traits>
T
tink2123 已提交
20
#include <unordered_map>
21
#include <vector>
22

23
#include "paddle/fluid/framework/op_version_registry.h"
24
#include "paddle/fluid/operators/common_infer_shape_functions.h"
25
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
D
dzhwinter 已提交
26
#include "paddle/fluid/platform/port.h"
Q
qijun 已提交
27

A
Adam 已提交
28 29
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
30 31 32
namespace paddle {
namespace operators {

33 34
using paddle::framework::Tensor;

35 36 37 38 39
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

40 41 42 43 44
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
45 46 47 48 49
      AddInput("X", "Input of " #OP_NAME                                     \
                    " operator, an N-D Tensor, with data type float32, "     \
                    "float64 or float16.");                                  \
      AddOutput("Out", "Output of " #OP_NAME                                 \
                       " operator, a Tensor with shape same as input.");     \
50 51 52 53 54 55 56 57 58
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
D
dzhwinter 已提交
59
  }
D
dzhwinter 已提交
60

H
hong 已提交
61 62
template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
63
 public:
H
hong 已提交
64
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
65 66

 protected:
67
  void Apply(GradOpPtr<T> op) const override {
H
hong 已提交
68 69 70 71
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
72

A
Adam 已提交
73 74
    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
75 76 77
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
H
hong 已提交
78
      op->SetInput("X", this->Input("X"));
79 80 81 82
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
H
hong 已提交
83
      op->SetInput("Out", this->Output("Out"));
84
    }
D
dzhwinter 已提交
85
  }
86
};
D
dzhwinter 已提交
87

88 89 90 91
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
M
mozga-intel 已提交
92
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
93
  auto data_type = oper.IndicateVarDataType(ctx, name);
94 95 96 97 98 99 100 101 102 103
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
104 105 106
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
107
      oper.CanMKLDNNBeUsed(ctx, data_type)) {
108
    library = framework::LibraryType::kMKLDNN;
M
mozga-intel 已提交
109
    layout = framework::DataLayout::kMKLDNN;
110 111
  }
#endif
112
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
113 114
}

Q
qijun 已提交
115 116 117 118
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

119
  void InferShape(framework::InferShapeContext* ctx) const override {
120
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
121
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
122
  }
123

124
 protected:
125 126 127 128
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
129 130
};

C
chengduo 已提交
131 132 133
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
134
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
C
chengduo 已提交
135
      const override {
136 137
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
138 139 140
  }
};

Q
qijun 已提交
141 142 143 144
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

145
  void InferShape(framework::InferShapeContext* ctx) const override {
146 147 148
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
149
  }
150

151
 protected:
152 153
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
154
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
155
  }
Q
qijun 已提交
156 157
};

D
dzhwinter 已提交
158
UNUSED constexpr char SigmoidDoc[] = R"DOC(
159
Sigmoid Activation Operator
K
Kexin Zhao 已提交
160

161
$$out = \\frac{1}{1 + e^{-x}}$$
K
Kexin Zhao 已提交
162

D
dzhwinter 已提交
163
)DOC";
Q
qijun 已提交
164

M
minghaoBD 已提交
165 166 167 168 169 170
UNUSED constexpr char SiluDoc[] = R"DOC(
Silu Activation Operator

$$out = x * \\frac{1}{1 + e^{-x}}$$
)DOC";

D
dzhwinter 已提交
171
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
172
Logsigmoid Activation Operator
K
Kexin Zhao 已提交
173

174
$$out = \\log \\frac{1}{1 + e^{-x}}$$
K
Kexin Zhao 已提交
175

D
dzhwinter 已提交
176
)DOC";
177

D
dzhwinter 已提交
178
UNUSED constexpr char ExpDoc[] = R"DOC(
179
Exp Operator. Computes exp of x element-wise with a natural number :math:`e` as the base.
K
Kexin Zhao 已提交
180

181
$$out = e^x$$
K
Kexin Zhao 已提交
182

D
dzhwinter 已提交
183
)DOC";
Q
qijun 已提交
184

D
dzhwinter 已提交
185
UNUSED constexpr char ReluDoc[] = R"DOC(
K
kexinzhao 已提交
186
Relu Activation Operator.
K
Kexin Zhao 已提交
187

188
$$out = \max(x, 0)$$
K
Kexin Zhao 已提交
189

D
dzhwinter 已提交
190
)DOC";
K
Kexin Zhao 已提交
191

D
dzhwinter 已提交
192
UNUSED constexpr char TanhDoc[] = R"DOC(
K
kexinzhao 已提交
193
Tanh Activation Operator.
K
Kexin Zhao 已提交
194

Q
update  
qiaolongfei 已提交
195
$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
K
Kexin Zhao 已提交
196

D
dzhwinter 已提交
197
)DOC";
198

D
dzhwinter 已提交
199
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
K
kexinzhao 已提交
200
TanhShrink Activation Operator.
K
Kexin Zhao 已提交
201

Y
Yan Chunwei 已提交
202
$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
K
Kexin Zhao 已提交
203

D
dzhwinter 已提交
204
)DOC";
K
Kexin Zhao 已提交
205

D
dzhwinter 已提交
206
UNUSED constexpr char SqrtDoc[] = R"DOC(
K
kexinzhao 已提交
207
Sqrt Activation Operator.
K
Kexin Zhao 已提交
208

N
Noel 已提交
209
$$out=\\sqrt{x}=x^{1/2}$$
210

211 212
**Note**:
  input value must be greater than or equal to zero.
K
Kexin Zhao 已提交
213

D
dzhwinter 已提交
214
)DOC";
215

Z
zhoukunsheng 已提交
216 217 218 219 220
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

221
$$out = \\frac{1}{\\sqrt{x}}$$
Z
zhoukunsheng 已提交
222 223 224

)DOC";

D
dzhwinter 已提交
225
UNUSED constexpr char CeilDoc[] = R"DOC(
226
Ceil Operator. Computes ceil of x element-wise.
D
dzhwinter 已提交
227

N
Noel 已提交
228
$$out = \\lceil x \\rceil$$
D
dzhwinter 已提交
229

D
dzhwinter 已提交
230
)DOC";
D
dzhwinter 已提交
231

D
dzhwinter 已提交
232
UNUSED constexpr char FloorDoc[] = R"DOC(
233
Floor Activation Operator. Computes floor of x element-wise.
D
dzhwinter 已提交
234

N
Noel 已提交
235
$$out = \\lfloor x \\rfloor$$
D
dzhwinter 已提交
236

D
dzhwinter 已提交
237
)DOC";
D
dzhwinter 已提交
238

D
dzhwinter 已提交
239
UNUSED constexpr char CosDoc[] = R"DOC(
240
Cosine Operator. Computes cosine of x element-wise.
C
add cos  
chengduoZH 已提交
241

Y
Yang Zhang 已提交
242 243
Input range is `(-inf, inf)` and output range is `[-1,1]`.

244
$$out = cos(x)$$
C
add cos  
chengduoZH 已提交
245

D
dzhwinter 已提交
246
)DOC";
C
add cos  
chengduoZH 已提交
247

J
joejiong 已提交
248 249 250 251 252 253 254 255 256
UNUSED constexpr char TanDoc[] = R"DOC(
Tangent Operator. Computes tangent of x element-wise.

Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.

$$out = tan(x)$$

)DOC";

D
dzhwinter 已提交
257
UNUSED constexpr char SinDoc[] = R"DOC(
C
add sin  
chengduoZH 已提交
258 259
Sine Activation Operator.

260
$$out = sin(x)$$
C
add sin  
chengduoZH 已提交
261

D
dzhwinter 已提交
262
)DOC";
C
add sin  
chengduoZH 已提交
263

264 265 266 267 268 269 270 271 272 273 274 275 276 277
UNUSED constexpr char SinhDoc[] = R"DOC(
Sinh Activation Operator.

$$out = sinh(x)$$

)DOC";

UNUSED constexpr char CoshDoc[] = R"DOC(
Cosh Activation Operator.

$$out = cosh(x)$$

)DOC";

D
dzhwinter 已提交
278
UNUSED constexpr char RoundDoc[] = R"DOC(
279
The OP rounds the values in the input to the nearest integer value.
D
dzhwinter 已提交
280

N
Noel 已提交
281
.. code-block:: text
282 283 284 285 286 287 288 289

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]
D
dzhwinter 已提交
290

D
dzhwinter 已提交
291
)DOC";
D
dzhwinter 已提交
292

D
dzhwinter 已提交
293
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
K
kexinzhao 已提交
294
Reciprocal Activation Operator.
K
Kexin Zhao 已提交
295

296
$$out = \\frac{1}{x}$$
K
Kexin Zhao 已提交
297

D
dzhwinter 已提交
298
)DOC";
299

D
dzhwinter 已提交
300
UNUSED constexpr char LogDoc[] = R"DOC(
K
kexinzhao 已提交
301
Log Activation Operator.
K
Kexin Zhao 已提交
302

303
$$out = \ln(x)$$
K
Kexin Zhao 已提交
304 305 306

Natural logarithm of x.

D
dzhwinter 已提交
307 308
)DOC";

J
joejiong 已提交
309 310 311 312 313 314 315 316 317
UNUSED constexpr char Log2Doc[] = R"DOC(
Log2 Activation Operator.

$$out = \log_2x$$

logarithm of x base to 2.

)DOC";

J
joejiong 已提交
318 319 320 321 322 323 324 325 326
UNUSED constexpr char Log10Doc[] = R"DOC(
Log10 Activation Operator.

$$out = \log_10_x$$

logarithm of x base to 10.

)DOC";

327 328 329 330 331 332 333 334 335
UNUSED constexpr char Log1pDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x+1)$

Natural logarithm of x.

)DOC";

D
dzhwinter 已提交
336
UNUSED constexpr char SquareDoc[] = R"DOC(
337
The OP square each elements of the inputs.
D
dzhwinter 已提交
338

339
$$out = x^2$$
340

D
dzhwinter 已提交
341 342
)DOC";

D
dzhwinter 已提交
343
UNUSED constexpr char SoftsignDoc[] = R"DOC(
D
dzhwinter 已提交
344 345
Softsign Activation Operator.

346
$$out = \\frac{x}{1 + \|x\|}$$
D
dzhwinter 已提交
347 348 349

)DOC";

T
tink2123 已提交
350 351 352 353 354 355
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
356
Arccosine Operator.
357

T
tink2123 已提交
358
$$out = \cos^{-1}(x)$$
359

T
tink2123 已提交
360 361 362
)DOC");
  }
};
363

T
tink2123 已提交
364 365 366
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
W
wawltor 已提交
367 368 369
    AddInput("X",
             "Input of asin operator, an N-D Tensor, with data type float32, "
             "float64 or float16.");
T
tink2123 已提交
370 371
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
372
Arcsine Operator.
373

T
tink2123 已提交
374
$$out = \sin^{-1}(x)$$
375

T
tink2123 已提交
376 377 378
)DOC");
  }
};
379

T
tink2123 已提交
380 381 382
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
W
wawltor 已提交
383 384 385
    AddInput("X",
             "Input of atan operator, an N-D Tensor, with data type float32, "
             "float64 or float16.");
T
tink2123 已提交
386 387
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
388
Arctangent Operator.
389

390
$$out = \tan^{-1}(x)$$
391

T
tink2123 已提交
392 393 394
)DOC");
  }
};
395

D
dzhwinter 已提交
396
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
397
 public:
Y
Yu Yang 已提交
398
  void Make() override {
W
Wilber 已提交
399 400 401 402 403 404 405 406
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
A
Adam 已提交
407 408 409
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
K
Kexin Zhao 已提交
410
    AddComment(R"DOC(
D
dzhwinter 已提交
411
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
412

W
Wilber 已提交
413
$$out = \max(x, \alpha * x)$$
K
Kexin Zhao 已提交
414 415

)DOC");
416 417 418
  }
};

419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448
class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of Softplus operator, an N-D Tensor, with data type "
             "float32, float64 or float16.");
    AddOutput(
        "Out",
        "Output of Softplus operator, a Tensor with shape same as input.");
    AddAttr<float>("beta", "The value of beta for Softplus.").SetDefault(1.0f);
    AddAttr<float>("threshold", "The value of threshold for Softplus.")
        .SetDefault(20.0f);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel.")
        .SetDefault(false);
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default false) Only used in cudnn kernel, need install cudnn.")
        .SetDefault(false);
    AddComment(R"DOC(
:strong:`Softplus Activation Operator`

..  math::
    out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\
    \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold.

)DOC");
  }
};

D
dzhwinter 已提交
449
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
450
 public:
Y
Yu Yang 已提交
451
  void Make() override {
D
dzhwinter 已提交
452 453 454
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
455
    AddComment(R"DOC(
456 457 458
:strong:`Softshrink Activation Operator`

..  math::
459
    out = \begin{cases}
460 461 462 463
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
464 465

)DOC");
K
kexinzhao 已提交
466 467 468
  }
};

D
dzhwinter 已提交
469
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
470
 public:
Y
Yu Yang 已提交
471
  void Make() override {
D
dzhwinter 已提交
472 473
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
474 475
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
476
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
477
    AddComment(R"DOC(
Y
yuyang18 已提交
478
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
479

Y
yuyang18 已提交
480 481 482 483 484 485
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
486 487

)DOC");
488 489 490
  }
};

491 492
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
493
  void Make() override {
494 495 496 497 498 499
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``X``.");
500 501 502 503
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
K
Kexin Zhao 已提交
504
    AddComment(R"DOC(
K
kexinzhao 已提交
505
BRelu Activation Operator.
K
Kexin Zhao 已提交
506

507
$$out = \min(\max(x, t_{min}), t_{max})$$
K
Kexin Zhao 已提交
508 509

)DOC");
510 511 512 513 514
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
515
  void Make() override {
516
    AddInput("X", "Input of SoftRelu operator");
F
fengjiayi 已提交
517
    AddOutput("Out", "Output of SoftRelu operator");
518 519
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
K
Kexin Zhao 已提交
520
    AddComment(R"DOC(
K
kexinzhao 已提交
521
SoftRelu Activation Operator.
K
Kexin Zhao 已提交
522

523
$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$
K
Kexin Zhao 已提交
524 525

)DOC");
526 527 528
  }
};

529 530
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
531
  void Make() override {
532 533 534 535 536 537
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
538
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
539
    AddComment(R"DOC(
K
kexinzhao 已提交
540
ELU Activation Operator.
K
Kexin Zhao 已提交
541 542 543 544

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

545
$$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$
K
Kexin Zhao 已提交
546 547

)DOC");
548 549 550
  }
};

551 552
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
553
  void Make() override {
Z
zhupengyang 已提交
554 555 556 557 558 559 560 561
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
562
        .SetDefault(6.0f);
A
Adam 已提交
563 564 565
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
K
Kexin Zhao 已提交
566
    AddComment(R"DOC(
K
kexinzhao 已提交
567
Relu6 Activation Operator.
K
Kexin Zhao 已提交
568

569
$$out = \min(\max(0, x), threshold)$$
K
Kexin Zhao 已提交
570 571

)DOC");
572 573 574
  }
};

575 576
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
577
  void Make() override {
578
    AddInput("X", "Input of Pow operator");
579 580 581 582 583
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this"
             "The shape of FactorTensor MUST BE [1]."
             "it has higher priority than attr(factor).")
        .AsDispensable();
F
fengjiayi 已提交
584
    AddOutput("Out", "Output of Pow operator");
585
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
K
Kexin Zhao 已提交
586
    AddComment(R"DOC(
K
kexinzhao 已提交
587
Pow Activation Operator.
K
Kexin Zhao 已提交
588

589
$$out = x^{factor}$$
K
Kexin Zhao 已提交
590 591

)DOC");
592 593 594 595 596
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
597
  void Make() override {
598 599
    AddInput("X",
             "Input of STanh operator."
N
Noel 已提交
600
             " A Tensor with type float32, float64.");
601 602 603
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
604 605
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
K
Kexin Zhao 已提交
606
    AddComment(R"DOC(
K
kexinzhao 已提交
607
STanh Activation Operator.
K
Kexin Zhao 已提交
608

Y
Yan Chunwei 已提交
609
$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
K
Kexin Zhao 已提交
610 611

)DOC");
Q
qijun 已提交
612 613 614
  }
};

615 616
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
617
  void Make() override {
618
    AddInput("X", "Input of ThresholdedRelu operator");
F
fengjiayi 已提交
619
    AddOutput("Out", "Output of ThresholdedRelu operator");
Y
yuyang18 已提交
620 621
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
622
        .SetDefault(1.0f);
K
Kexin Zhao 已提交
623
    AddComment(R"DOC(
Y
yuyang18 已提交
624
:strong:`ThresholdedRelu activation operator`
K
Kexin Zhao 已提交
625

Y
yuyang18 已提交
626
..  math::
K
Kexin Zhao 已提交
627

Y
yuyang18 已提交
628
    out = \begin{cases}
Y
yuyang18 已提交
629
             x,  \text{if } x > threshold \\
Y
yuyang18 已提交
630 631
             0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
632
)DOC");
633 634 635
  }
};

636 637
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
638
  void Make() override {
639 640 641 642 643
    AddInput("X", "An N-D Tensor with data type float32, float64. ");
    AddOutput("Out", "A Tensor with the same shape as input. ");
    AddAttr<float>("slope",
                   "The slope of the linear approximation of sigmoid. Its "
                   "value MUST BE positive. Default is 0.2. ")
644
        .SetDefault(0.2f);
645 646 647
    AddAttr<float>(
        "offset",
        "The offset of the linear approximation of sigmoid. Default is 0.5. ")
648
        .SetDefault(0.5f);
649
    AddComment(R"DOC(
K
kexinzhao 已提交
650
HardSigmoid Activation Operator.
651

652
A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
K
Kexin Zhao 已提交
653
which is much faster than sigmoid.
654

655
$$out = \max(0, \min(1, slope * x + offset))$$
656

K
Kexin Zhao 已提交
657
)DOC");
658 659 660
  }
};

A
Abhinav Arora 已提交
661 662
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
663
  void Make() override {
A
Abhinav Arora 已提交
664
    AddInput("X", "Input of Swish operator");
F
fengjiayi 已提交
665
    AddOutput("Out", "Output of Swish operator");
A
Abhinav Arora 已提交
666
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
667 668 669
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
A
Abhinav Arora 已提交
670 671 672
    AddComment(R"DOC(
Swish Activation Operator.

673
$$out = \\frac{x}{1 + e^{- \beta \ x}}$$
A
Abhinav Arora 已提交
674 675 676 677 678

)DOC");
  }
};

H
huangjun12 已提交
679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

695
$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$
H
huangjun12 已提交
696 697 698 699 700 701 702 703 704

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

D
dzhwinter 已提交
705
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
M
minghaoBD 已提交
706
REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
D
dzhwinter 已提交
707 708 709 710 711 712
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
Z
zhoukunsheng 已提交
713
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
D
dzhwinter 已提交
714 715 716
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
J
joejiong 已提交
717
REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
D
dzhwinter 已提交
718
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
719 720
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
D
dzhwinter 已提交
721 722 723
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
J
joejiong 已提交
724
REGISTER_ACTIVATION_OP_MAKER(Log2, Log2Doc);
J
joejiong 已提交
725
REGISTER_ACTIVATION_OP_MAKER(Log10, Log10Doc);
726
REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);
D
dzhwinter 已提交
727 728 729
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

730
template <ActBwdOpFwdDeps kDepValue>
731 732 733 734 735
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
736
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
737
      if (ctx->HasOutput("DX")) {
738 739 740
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
741
      if (ctx->HasOutput("DDOut")) {
742 743 744
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
745
    }
746
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
747
      if (ctx->HasOutput("DOut")) {
748 749 750
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
779 780 781
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
782 783 784 785 786 787 788 789 790 791
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812
template <typename T>
class SigmoidDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sigmoid_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832
template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

833 834
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
H
hong 已提交
835 836
template <typename T>
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
837
 public:
H
hong 已提交
838
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
839 840

 protected:
841
  void Apply(GradOpPtr<T> op) const override {
842 843
    op->SetType("relu_grad_grad");
    // input1: Out
H
hong 已提交
844
    op->SetInput("Out", this->Input("Out"));
Q
qingqing01 已提交
845
    // input2: ddx
H
hong 已提交
846 847
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
848
    // output: ddy
H
hong 已提交
849
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
850 851 852
  }
};

853 854
// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
H
hong 已提交
855
template <typename T>
856
class LeakyReluDoubleGradMaker
H
hong 已提交
857
    : public ::paddle::framework::SingleGradOpMaker<T> {
858
 public:
H
hong 已提交
859
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
860 861

 protected:
862
  void Apply(GradOpPtr<T> op) const override {
863
    op->SetType("leaky_relu_grad_grad");
864 865
    // input1: X
    op->SetInput("X", this->Input("X"));
866
    // X@GRAD@GRAD: ddx
H
hong 已提交
867 868
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
869
    // Out@GRAD@GRAD: ddy
H
hong 已提交
870
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
871 872 873
  }
};

D
Double_V 已提交
874 875 876 877 878 879 880 881
// elu grad: dx=dy if y>0 else alpha*dy*x.exp()
// elu gradgrad: ddx=ddy if y>0 else alpha*ddy*x.exp()
template <typename T>
class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
882
  void Apply(GradOpPtr<T> op) const override {
D
Double_V 已提交
883 884 885 886 887 888 889 890 891 892 893 894 895 896
    op->SetType("elu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // Out@GRAD@GRAD: ddy
    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

L
lvmengsi 已提交
897 898
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
H
hong 已提交
899 900
template <typename T>
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
L
lvmengsi 已提交
901
 public:
H
hong 已提交
902
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
L
lvmengsi 已提交
903 904

 protected:
905
  void Apply(GradOpPtr<T> op) const override {
L
lvmengsi 已提交
906
    op->SetType("sqrt_grad_grad");
H
hong 已提交
907 908 909 910 911 912
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
L
lvmengsi 已提交
913 914 915
  }
};

W
whs 已提交
916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * ddx
template <typename T>
class RsqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("rsqrt_grad_grad");
    op->SetInput("Out", this->Input("Out"));
    op->SetInput("DX", this->Output(framework::GradVarName("X")));
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("DOut", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

935 936
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
H
hong 已提交
937 938
template <typename T>
class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
939
 public:
H
hong 已提交
940
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
941 942

 protected:
943
  void Apply(GradOpPtr<T> op) const override {
944
    op->SetType("square_grad_grad");
H
hong 已提交
945
    op->SetInput("X", this->Input("X"));
946
    // Out@GRAD: dy
H
hong 已提交
947
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
948
    // X@GRAD@GRAD: ddx
H
hong 已提交
949
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
950

H
hong 已提交
951
    op->SetAttrMap(this->Attrs());
952 953

    // X@GRAD: dx
H
hong 已提交
954
    op->SetOutput("DX", this->InputGrad("X"));
955
    // Out@GRAD@GRAD: ddy
H
hong 已提交
956
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
957 958 959
  }
};

960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981
// log Grad: dx = dout / x
// log Grad Grad: ddout = ddx / x; dx = -(dout / x) * (ddx / x)
template <typename T>
class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("log_grad_grad");
    op->SetInput("X", this->Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // X@GRAD: dx
    op->SetOutput("DX", this->InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

982
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
983 984
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
985
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
986
                           {"DDX", "DDOut"});
987

H
hong 已提交
988 989
template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
990
 public:
H
hong 已提交
991
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
992 993

 protected:
994
  void Apply(GradOpPtr<T> op) const override {
995
    op->SetType("pow_grad");
H
hong 已提交
996 997 998 999 1000
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetInput("FactorTensor", this->Input("FactorTensor"));
    op->SetAttrMap(this->Attrs());
1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
1055
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
Q
qijun 已提交
1056 1057 1058 1059
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
1060
namespace plat = paddle::platform;
1061

1062 1063 1064 1065
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
H
hong 已提交
1066 1067 1068 1069
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
1070
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
1071
                       ops::ActFwdInplaceInferer, void>::type);             \
1072
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
1073
                    ops::ActivationGradOpInplaceInferer);
1074 1075 1076

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
Q
QI JUN 已提交
1077 1078 1079 1080 1081 1082 1083 1084 1085 1086
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
Y
Yu Yang 已提交
1087
                                ops::grad_functor<double>>);
1088

1089 1090
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
1091

1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132
/* ==========================    sigmoid register  =============================
 */
// 1. Register Sigmoid Operator
REGISTER_OPERATOR(
    sigmoid, ops::ActivationOp, ops::SigmoidOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SigmoidGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::SigmoidGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer, void>::type);

// 2. Register Sigmoid Grad Operator
REGISTER_OPERATOR(sigmoid_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>)

// 3. Register Sigmoid DoubleGrad Operator
REGISTER_OPERATOR(
    sigmoid_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SigmoidGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

// Register Sigmoid/GradSigmoid Kernels
REGISTER_ACTIVATION_CPU_KERNEL(sigmoid, Sigmoid, SigmoidFunctor,
                               SigmoidGradFunctor);

// Register DoubleGrad Kernel
REGISTER_OP_CPU_KERNEL(
    sigmoid_grad_grad,
    ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
                                 ops::SigmoidGradGradFunctor<float>>,
    ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
                                 ops::SigmoidGradGradFunctor<double>>,
    ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::float16>>);

/* ========================================================================== */

1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160
/* ==========================    tanh register  ============================= */
REGISTER_OPERATOR(
    tanh, ops::ActivationOp, ops::TanhOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::TanhGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(tanh_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::TanhDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::TanhDoubleGradMaker<paddle::imperative::OpBase>)
REGISTER_OPERATOR(
    tanh_grad_grad,
    ops::ActivationOpDoubleGrad<ops::TanhGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(tanh, Tanh, TanhFunctor, TanhGradFunctor);
REGISTER_OP_CPU_KERNEL(
    tanh_grad_grad, ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::TanhGradGradFunctor<float>>,
    ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);
/* ========================================================================== */

1161
/* ==========================    relu register  ============================= */
1162 1163
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
H
hong 已提交
1164 1165 1166 1167
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
1168
    ops::ActFwdInplaceInferer);
1169
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
1170
                  ops::ActivationGradOpInplaceInferer,
H
hong 已提交
1171 1172
                  ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
1173 1174
REGISTER_OPERATOR(
    relu_grad_grad,
1175
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
1176
    ops::ActivationDoubleGradOpInplaceInferer);
1177

1178
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor);
1179 1180 1181 1182 1183 1184 1185 1186 1187

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
1188
/* ========================================================================== */
1189

1190
/* ======================== leaky relu register  ============================ */
1191 1192 1193
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
H
hong 已提交
1194 1195 1196 1197
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
1198
    ops::ActFwdInplaceInferer);
1199
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
1200
                  ops::ActivationGradOpInplaceInferer,
H
hong 已提交
1201 1202
                  ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
1203 1204
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
1205
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
1206
    ops::ActivationDoubleGradOpInplaceInferer);
1207

1208 1209 1210 1211 1212 1213 1214 1215 1216 1217
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
1218 1219
/* ========================================================================== */

D
Double_V 已提交
1220 1221 1222 1223 1224 1225 1226 1227 1228
/* ========================    elu  register     ============================ */
REGISTER_OPERATOR(
    elu, ops::ActivationOp, ops::ELUOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
1229
                  ops::ActivationGradOpInplaceInferer,
D
Double_V 已提交
1230 1231 1232 1233 1234
                  ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    elu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
1235
    ops::ActivationDoubleGradOpInplaceInferer);
D
Double_V 已提交
1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247

REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
    elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                                            ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);

/* ========================================================================== */

L
lvmengsi 已提交
1248 1249 1250
/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
H
hong 已提交
1251 1252 1253 1254
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
1255
    ops::ActFwdInplaceInferer);
L
lvmengsi 已提交
1256
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
1257
                  ops::ActivationGradOpInplaceInferer,
H
hong 已提交
1258 1259
                  ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
L
lvmengsi 已提交
1260 1261
REGISTER_OPERATOR(
    sqrt_grad_grad,
1262
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
1263
    ops::ActivationDoubleGradOpInplaceInferer);
1264

L
lvmengsi 已提交
1265 1266 1267 1268 1269 1270 1271 1272 1273 1274
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

W
whs 已提交
1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303
/* ===========================   rsqrt register  =============================
 */
REGISTER_OPERATOR(
    rsqrt, ops::ActivationOp, ops::RsqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::RsqrtGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(rsqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::RsqrtDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::RsqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    rsqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

1304 1305 1306 1307
/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
H
hong 已提交
1308 1309 1310 1311
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
1312
    ops::ActFwdInplaceInferer);
1313
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
1314
                  ops::ActivationGradOpInplaceInferer,
H
hong 已提交
1315 1316
                  ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
1317 1318
REGISTER_OPERATOR(
    square_grad_grad,
1319
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
1320
    ops::ActivationDoubleGradOpInplaceInferer);
1321

1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339
REGISTER_OP_CPU_KERNEL(square,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::SquareFunctor<float>>,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::SquareFunctor<double>>,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::SquareFunctor<int>>,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::SquareFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
    square_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                                           ops::SquareGradFunctor<float>>,
    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                              ops::SquareGradFunctor<double>>,
    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                              ops::SquareGradFunctor<int>>,
    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                              ops::SquareGradFunctor<int64_t>>);
1340 1341 1342 1343 1344 1345 1346 1347

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
1348 1349 1350 1351 1352
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
1353
/* ========================================================================== */
1354 1355 1356 1357 1358

/* ==========================   pow register  ============================ */

REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
H
hong 已提交
1359 1360
    ops::PowGradOpMaker<paddle::framework::OpDesc>,
    ops::PowGradOpMaker<paddle::imperative::OpBase>,
1361
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
1362
                     ops::ActFwdInplaceInferer, void>::type);
1363
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
1364
                  ops::ActivationGradOpInplaceInferer);
1365 1366 1367

REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
1368 1369 1370
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<int64_t>>);
1371 1372 1373
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<int>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<int64_t>>);
/* ========================================================================== */

/* ==========================   exp register  ============================ */
REGISTER_OPERATOR(
    exp, ops::ActivationOp, ops::ExpOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::ExpGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ExpGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::ExpGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(exp_grad, ops::ActivationOpGrad,
1389
                  ops::ActivationGradOpInplaceInferer);
1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409

REGISTER_OP_CPU_KERNEL(exp,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::ExpFunctor<float>>,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::ExpFunctor<double>>,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::ExpFunctor<int>>,
                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
                                             ops::ExpFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
    exp_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                                        ops::ExpGradFunctor<float>>,
    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                              ops::ExpGradFunctor<double>>,
    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                              ops::ExpGradFunctor<int>>,
    ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
                              ops::ExpGradFunctor<int64_t>>);
/* ========================================================================== */
1410

1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439
/* ==========================  Log register ==================================*/
REGISTER_OPERATOR(
    log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(log_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::LogDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::LogDoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
    log_grad_grad,
    ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(log, Log, LogFunctor, LogGradFunctor);

REGISTER_OP_CPU_KERNEL(
    log_grad_grad, ops::LogDoubleGradKernel<plat::CPUDeviceContext,
                                            ops::LogGradGradFunctor<float>>,
    ops::LogDoubleGradKernel<plat::CPUDeviceContext,
                             ops::LogGradGradFunctor<double>>,
    ops::LogDoubleGradKernel<plat::CPUDeviceContext,
                             ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */

1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458
/* ==========================  register checkpoint ===========================*/
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculate formula before checkponit: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculate formula before checkponit: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

1459 1460 1461 1462 1463 1464 1465 1466 1467 1468
REGISTER_OP_VERSION(softplus)
    .AddCheckpoint(
        R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
        paddle::framework::compatible::OpVersionDesc()
            .NewAttr("beta", "The beta value of the new formula", 1.0f)
            .NewAttr("threshold", "The threshold value of the new formula",
                     20.0f));

1469
/* ========================================================================== */