activation_op.cc 34.0 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Q
qijun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Q
qijun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
T
tink2123 已提交
16
#include <memory>
D
dzhwinter 已提交
17
#include <string>
18
#include <type_traits>
T
tink2123 已提交
19
#include <unordered_map>
20
#include <vector>
21
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
D
dzhwinter 已提交
22
#include "paddle/fluid/platform/port.h"
23 24 25
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
Q
qijun 已提交
26

A
Adam 已提交
27 28
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
29 30 31
namespace paddle {
namespace operators {

32 33
using paddle::framework::Tensor;

34 35 36 37 38
// Reports whether GradFunctor's forward dependency is kDepOut or kNoDeps,
// i.e. the backward functor does not declare a dependency on X.
// NOTE(review): presumably this is the condition under which the forward op
// may write Out over X's buffer -- confirm at the registration site.
// Kept as a single return statement so it remains a valid C++11 constexpr.
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return (GradFunctor::FwdDeps() == kNoDeps) ||
         (GradFunctor::FwdDeps() == kDepOut);
}

39 40 41 42 43
// Defines `<OP_NAME>OpMaker`, an OpProtoAndCheckerMaker that declares the
// input "X", the output "Out", the bool attributes use_mkldnn / use_cudnn /
// is_test (each defaulting to false) and attaches OP_COMMENT as the comment.
// The expansion ends at the closing brace of the class, so every use must
// supply the trailing ';' itself: REGISTER_ACTIVATION_OP_MAKER(Name, Doc);
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                     \
  class OP_NAME##OpMaker                                                      \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                  \
   public:                                                                    \
    void Make() override {                                                    \
      AddInput("X",                                                           \
               "Input of " #OP_NAME " operator, an N-D Tensor, with data "    \
               "type float32, float64 or float16.");                          \
      AddOutput("Out",                                                        \
                "Output of " #OP_NAME                                         \
                " operator, a Tensor with shape same as input.");             \
      AddAttr<bool>("use_mkldnn",                                             \
                    "(bool, default false) Only used in mkldnn kernel")       \
          .SetDefault(false);                                                 \
      AddAttr<bool>("use_cudnn",                                              \
                    "(bool, default false) Only used in cudnn kernel, "       \
                    "need install cudnn")                                     \
          .SetDefault(false);                                                 \
      AddAttr<bool>("is_test",                                                \
                    "(bool, default false) Set to true for inference "        \
                    "only, false for training. Some layers may run faster "   \
                    "when this is true.")                                     \
          .SetDefault(false);                                                 \
      AddComment(OP_COMMENT);                                                 \
    }                                                                         \
  }
D
dzhwinter 已提交
64

65 66 67 68 69 70 71 72 73 74 75 76 77
// Builds the "<forward type>_grad" op description for an activation.
// kDepValue is treated as a bit mask: the kDepX bit wires the forward input
// X into the grad op, the kDepOut bit wires the forward output Out.
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns the grad op: Out@GRAD in, X@GRAD out, forward attributes copied.
  std::unique_ptr<framework::OpDesc> Apply() const override {
    constexpr int kDeps = static_cast<int>(kDepValue);
    constexpr bool kReadsX =
        (kDeps & static_cast<int>(ActBwdOpFwdDeps::kDepX)) != 0;
    constexpr bool kReadsOut =
        (kDeps & static_cast<int>(ActBwdOpFwdDeps::kDepOut)) != 0;

    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType(ForwardOpType() + "_grad");
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetAttrMap(Attrs());

    // X is forwarded when the functor depends on it, and also whenever
    // MKL-DNN is requested through the global flag or the per-op attribute.
    if (kReadsX || FLAGS_use_mkldnn ||
        (grad_op->HasAttr("use_mkldnn") &&
         boost::get<bool>(grad_op->GetAttr("use_mkldnn")))) {
      grad_op->SetInput("X", Input("X"));
    }

    if (kReadsOut) {
      grad_op->SetInput("Out", Output("Out"));
    }

    return grad_op;
  }
};
D
dzhwinter 已提交
93

94 95 96 97
// Chooses the kernel type for an activation operator.
//   ctx  : execution context; supplies the place and the variable `name`.
//   oper : the operator whose attribute map is inspected for "use_mkldnn".
//   name : input variable whose data type becomes the kernel data type.
// The result defaults to the plain library with any layout; when built with
// PADDLE_WITH_MKLDNN, the attribute is present and CanMKLDNNBeUsed(ctx)
// holds, the MKL-DNN library and layout are selected instead.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto library = framework::LibraryType::kPlain;
  auto layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  const auto& attrs = oper.Attrs();
  const bool has_mkldnn_attr = attrs.find("use_mkldnn") != attrs.end();
  if (library == framework::LibraryType::kPlain && has_mkldnn_attr &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  const auto data_type = framework::GetDataTypeOfVar(ctx.InputVar(name));
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

Q
qijun 已提交
122 123 124 125
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

126
  void InferShape(framework::InferShapeContext* ctx) const override {
127
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
128
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
129
  }
130

131
 protected:
132 133 134 135
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
136 137
};

C
chengduo 已提交
138 139 140 141 142 143
// Var-type inference for activations: declares that output "Out" has the
// same dtype and variable type as input "X".
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  // Returns the single input -> output pairing {"X" -> "Out"}.
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    std::unordered_map<std::string, std::string> same_type;
    same_type.emplace("X", /*->*/ "Out");
    return same_type;
  }
};

Q
qijun 已提交
147 148 149 150
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

151
  void InferShape(framework::InferShapeContext* ctx) const override {
152 153 154
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
155
  }
156

157
 protected:
158 159
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
160
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
161
  }
Q
qijun 已提交
162 163
};

D
dzhwinter 已提交
164
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc).
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";
Q
qijun 已提交
170

D
dzhwinter 已提交
171
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc).
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";
177

D
dzhwinter 已提交
178
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc).
UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Operator. Computes exp of x element-wise with a natural number :math:`e` as the base.

$out = e^x$

)DOC";
Q
qijun 已提交
184

D
dzhwinter 已提交
185
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc).
UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";
K
Kexin Zhao 已提交
191

C
Clementine 已提交
192 193 194 195 196 197 198
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc).
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

D
dzhwinter 已提交
199
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc).
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";
205

D
dzhwinter 已提交
206
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc).
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";
K
Kexin Zhao 已提交
212

D
dzhwinter 已提交
213
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc).
UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

.. math:: out=\sqrt x=x^{1/2}

**Note**:
  input value must be greater than or equal to zero.

)DOC";
222

Z
zhoukunsheng 已提交
223 224 225 226 227 228 229 230 231
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc).
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

D
dzhwinter 已提交
232
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc).
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";
238

D
dzhwinter 已提交
239
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc).
UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Operator. Computes ceil of x element-wise.

$out = \left \lceil x \right \rceil$

)DOC";
D
dzhwinter 已提交
245

D
dzhwinter 已提交
246
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc).
UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";
D
dzhwinter 已提交
252

D
dzhwinter 已提交
253
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc).
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Operator. Computes cosine of x element-wise.

$out = cos(x)$

)DOC";
C
add cos  
chengduoZH 已提交
259

D
dzhwinter 已提交
260
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc).
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";
C
add sin  
chengduoZH 已提交
266

D
dzhwinter 已提交
267
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc).
UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.

.. code-block:: python

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]

)DOC";
D
dzhwinter 已提交
281

D
dzhwinter 已提交
282
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc).
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";
288

D
dzhwinter 已提交
289
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc).
UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

D
dzhwinter 已提交
298
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc).
UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

D
dzhwinter 已提交
305
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc).
UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

D
dzhwinter 已提交
312
// Comment text for the op maker generated by
// REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc).
UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

T
tink2123 已提交
319 320 321 322 323 324
// Proto maker for the acos op: one input "X", one output "Out", no attributes.
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC";
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(kDoc);
  }
};
332

T
tink2123 已提交
333 334 335 336 337 338
// Proto maker for the asin op: one input "X", one output "Out", no attributes.
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC";
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(kDoc);
  }
};
346

T
tink2123 已提交
347 348 349 350 351 352
// Proto maker for the atan op: one input "X", one output "Out", no attributes.
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    // The op is atan (inverse tangent); the comment previously described the
    // inverse hyperbolic tangent, which is a different function.
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};
360

D
dzhwinter 已提交
361
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
362
 public:
Y
Yu Yang 已提交
363
  void Make() override {
W
Wilber 已提交
364 365 366 367 368 369 370 371
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
A
Adam 已提交
372 373 374 375 376 377 378
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
K
Kexin Zhao 已提交
379
    AddComment(R"DOC(
D
dzhwinter 已提交
380
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
381

W
Wilber 已提交
382
$$out = \max(x, \alpha * x)$$
K
Kexin Zhao 已提交
383 384

)DOC");
385 386 387
  }
};

D
dzhwinter 已提交
388
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
389
 public:
Y
Yu Yang 已提交
390
  void Make() override {
D
dzhwinter 已提交
391 392 393
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
394
    AddComment(R"DOC(
395 396 397
:strong:`Softshrink Activation Operator`

..  math::
398
    out = \begin{cases}
399 400 401 402
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
403 404

)DOC");
K
kexinzhao 已提交
405 406 407
  }
};

D
dzhwinter 已提交
408
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
409
 public:
Y
Yu Yang 已提交
410
  void Make() override {
D
dzhwinter 已提交
411 412
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
413 414
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
415
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
416
    AddComment(R"DOC(
Y
yuyang18 已提交
417
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
418

Y
yuyang18 已提交
419 420 421 422 423 424
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
425 426

)DOC");
427 428 429
  }
};

430 431
// Proto maker for the brelu op: input "X", output "Out" and the float
// attributes "t_min" (default 0) and "t_max" (default 24).
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC";
    AddInput("X",
             "The input is a multi-dimensional Tensor. "
             "The data type is float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has "
              "same dimension and data type as the ``X``.");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(0.0f);
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(24.0f);
    AddComment(kDoc);
  }
};

// Proto maker for the soft_relu op: input "X", output "Out" and the float
// attribute "threshold" (default 40).
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC";
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(kDoc);
  }
};

468 469
// Proto maker for the elu op: input "X", output "Out" and the float
// attribute "alpha" (default 1).
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC";
    AddInput("X",
             "The input is a multi-dimensional Tensor. "
             "The data type is float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has "
              "same dimension and data type as the ``x``.");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(kDoc);
  }
};

490 491
// Proto maker for the relu6 op: input "X", output "Out" and the float
// attribute "threshold" (default 6).
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    // The upper bound is the configurable `threshold` attribute; 6 is only
    // its default, so the formula names the attribute instead of the literal.
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), threshold)$

)DOC");
  }
};

506 507
// Proto maker for the pow op: input "X", optional input "FactorTensor",
// output "Out" and the float attribute "factor" (default 1).
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    // Adjacent string literals are concatenated verbatim, so each sentence
    // carries its own terminator and trailing space.
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this. "
             "The shape of FactorTensor MUST BE [1]. "
             "It has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

// Proto maker for the stanh op: input "X", output "Out" and the float
// attributes "scale_a" (default 0.67) and "scale_b" (default 1.7159).
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC";
    AddInput("X",
             "Input of STanh operator. "
             "A LoDTensor or Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(kDoc);
  }
};

546 547
// Proto maker for the thresholded_relu op: input "X", output "Out" and the
// float attribute "threshold" (default 1).
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC";
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(kDoc);
  }
};

567 568
// Proto maker for the hard_sigmoid op: input "X", output "Out" and the float
// attributes "slope" (default 0.2) and "offset" (default 0.5).
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    // The additive term is called `offset` throughout, matching the declared
    // attribute (the text previously referred to an undeclared `shift`).
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

A
Abhinav Arora 已提交
592 593
// Proto maker for the swish op: input "X", output "Out" and the float
// attribute "beta" (default 1).
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC";
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(kDoc);
  }
};

H
huangjun12 已提交
607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632
// Proto maker for the hard_swish op: input "X", output "Out" and the float
// attributes "threshold" (default 6), "scale" (default 6), "offset"
// (default 3).
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Operator comment, kept apart from the registration calls below.
    static const char kDoc[] = R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC";
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold",
                   "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(kDoc);
  }
};

D
dzhwinter 已提交
633 634 635 636
// Instantiate `<Name>OpMaker` for every activation whose proto is just
// X -> Out plus the common attributes; the second argument is the comment
// string attached to the op. The trailing ';' terminates the class
// definition that the macro expands to.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

654
template <ActBwdOpFwdDeps kDepValue>
655 656 657 658 659
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
660
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
661
      if (ctx->HasOutput("DX")) {
662 663 664
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
665
      if (ctx->HasOutput("DDOut")) {
666 667 668
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
669
    }
670
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
671
      if (ctx->HasOutput("DOut")) {
672 673 674
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Second-order gradient operator variant whose InferShape only shapes DDOut.
// kDepValue is treated as a bit mask selecting whether X and/or Out supplies
// that shape.
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // DDOut (when present) takes dims/LoD from X for the kDepX bit and from
  // Out for the kDepOut bit; if both bits are set, Out is applied last.
  void InferShape(framework::InferShapeContext* ctx) const override {
    // Copies dims and LoD from `src` to DDOut, but only if DDOut exists.
    auto share_to_ddout = [ctx](const char* src) {
      if (!ctx->HasOutput("DDOut")) return;
      ctx->ShareDim(src, "DDOut");
      ctx->ShareLoD(src, "DDOut");
    };

    const int deps = static_cast<int>(kDepValue);
    if (deps & static_cast<int>(kDepX)) {
      share_to_ddout("X");
    }
    if (deps & static_cast<int>(kDepOut)) {
      share_to_ddout("Out");
    }
  }

 protected:
  // Delegates to GetKernelType, keyed on the input variable "DDX".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const std::string type_source = "DDX";
    return GetKernelType(ctx, *this, type_source);
  }
};

716 717 718 719 720 721 722 723 724 725 726 727 728 729
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
// Builds the "relu_grad_grad" op description from the relu grad op.
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns the op description: inputs Out and DDX, output DDOut, with the
  // attributes of the grad op copied over.
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the start so the description is released
    // if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

739 740 741 742 743 744 745 746 747 748 749
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
//
// Builds the "leaky_relu_grad_grad" op description from the leaky_relu grad
// op.
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns the op description: inputs Out and DDX, output DDOut, with the
  // attributes of the grad op copied over.
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the start so the description is released
    // if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("leaky_relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

L
lvmengsi 已提交
761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// NOTE(review): differentiating dx = 0.5 * dy / y with respect to y gives
// -dx / y, so the second formula appears to be missing a "/ y" -- confirm
// against the sqrt_grad_grad kernel.
//
// Builds the "sqrt_grad_grad" op description from the sqrt grad op.
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns the op description: inputs Out, DX and DDX, outputs DOut and
  // DDOut, with the attributes of the grad op copied over.
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the start so the description is released
    // if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    // DX is the grad op's own output X@GRAD.
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
//
// Builds the "square_grad_grad" op description from the square grad op.
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns the op description: inputs X, DOut and DDX, outputs DX and
  // DDOut, with the attributes of the grad op copied over.
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the start so the description is released
    // if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

808 809 810
// Declares ActivationGradOpInplaceInference with the pair
// {Out@GRAD, X@GRAD}.
// NOTE(review): presumably this lets X@GRAD reuse the buffer of Out@GRAD;
// confirm against the DECLARE_INPLACE_OP_INFERER definition.
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
811 812
// Declares ActivationDoubleGradOpInplaceInference with the pair
// {DDX, DDOut}.
// NOTE(review): presumably this lets DDOut reuse the buffer of DDX; confirm
// against the DECLARE_INPLACE_OP_INFERER definition.
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
                           {"DDX", "DDOut"});
813

814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882
// Grad-op description maker for `pow`; produces a `pow_grad` op.
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Wires X, Out@GRAD and FactorTensor in as inputs, X@GRAD as the output,
  // and copies the attributes returned by Attrs().
  std::unique_ptr<framework::OpDesc> Apply() const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType("pow_grad");
    grad_op->SetInput("X", Input("X"));
    grad_op->SetInput(out_grad_name, OutputGrad("Out"));
    grad_op->SetOutput(x_grad_name, InputGrad("X"));
    grad_op->SetInput("FactorTensor", Input("FactorTensor"));
    grad_op->SetAttrMap(Attrs());
    return grad_op;
  }
};
// Operator class for `pow`.
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // "Out" takes both its dimensions and its LoD from "X".
  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel type is derived from the "X" input.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  // "FactorTensor" is given the expected kernel type unchanged; every other
  // variable keeps its own place and layout and only adopts the expected
  // data type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    const bool is_factor_tensor = (var_name == "FactorTensor");
    if (!is_factor_tensor) {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
    return expected_kernel_type;
  }
};

// Operator class for `pow_grad`.
class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // X@GRAD takes both its dimensions and its LoD from Out@GRAD.
  void InferShape(framework::InferShapeContext* ctx) const override {
    const auto dout_name = framework::GradVarName("Out");
    const auto dx_name = framework::GradVarName("X");
    ctx->ShareDim(dout_name, dx_name);
    ctx->ShareLoD(dout_name, dx_name);
  }

 protected:
  // The kernel type is derived from the Out@GRAD input.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  // "FactorTensor" is given the expected kernel type unchanged; every other
  // variable keeps its own place and layout and only adopts the expected
  // data type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    const bool is_factor_tensor = (var_name == "FactorTensor");
    if (!is_factor_tensor) {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
    return expected_kernel_type;
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
887
namespace plat = paddle::platform;
// Registers one activation operator together with its gradient operator.
//
//   KERNEL_TYPE  - operator type name; the gradient operator is registered
//                  under KERNEL_TYPE##_grad.
//   OP_NAME      - prefix of the maker class, ops::OP_NAME##OpMaker.
//   functor      - not used by this macro; it is part of the parameter list
//                  so the macro has the same arity as
//                  REGISTER_ACTIVATION_CPU_KERNEL.
//   grad_functor - gradient functor template; grad_functor<float>::FwdDeps()
//                  parameterizes ActivationGradOpDescMaker, and
//                  CanInplaceAct<grad_functor<float>>() decides at compile
//                  time whether SingleOpInplaceInToOut or `void` is passed.
//
// Every line of the definition except the last must end in a backslash; a
// line without one terminates the macro.
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);
// Registers the CPU kernels of one activation operator and of its gradient
// operator, each for float and double.
//
//   act_type     - operator type name; the gradient kernels are registered
//                  under act_type##_grad.
//   op_name      - not used by this macro.
//   functor      - forward functor template, instantiated as functor<float>
//                  and functor<double> inside ops::ActivationKernel.
//   grad_functor - gradient functor template, instantiated the same way
//                  inside ops::ActivationGradKernel.
//
// Every line of the definition except the last must end in a backslash; a
// line without one terminates the macro.
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);
// Apply both registration macros through FOR_EACH_ACTIVATION_OP: first the
// operator/grad-operator registration, then the CPU kernel registration.
// NOTE(review): FOR_EACH_ACTIVATION_OP is defined outside this section;
// presumably it expands its argument once per listed activation -- confirm
// against its definition. relu, leaky_relu, sqrt, square and pow are
// registered explicitly further down in this file.
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
/* ==========================    relu register  ============================= */
// relu is registered by hand (not through REGISTER_ACTIVATION_OP) because its
// gradient operator also lists ops::ReluDoubleGradMaker and a separate
// relu_grad_grad operator is registered.
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// Forward and first-order gradient CPU kernels (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

// Second-order gradient CPU kernels (float, double, float16).
REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ======================== leaky relu register  ============================ */
// leaky_relu is registered by hand (not through REGISTER_ACTIVATION_OP)
// because its gradient operator also lists ops::LeakyReluDoubleGradMaker and
// a separate leaky_relu_grad_grad operator is registered.
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// Forward and first-order gradient CPU kernels (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
// Second-order gradient CPU kernels (float, double, float16).
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
// sqrt is registered by hand (not through REGISTER_ACTIVATION_OP) because its
// gradient operator also lists ops::SqrtDoubleGradMaker and a separate
// sqrt_grad_grad operator is registered. Unlike relu/leaky_relu, the
// double-grad operator uses ActivationOpDoubleGrad parameterized by the
// FwdDeps() of the grad-grad functor.
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// Forward and first-order gradient CPU kernels (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
// Second-order gradient CPU kernels (float, double, float16).
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
// square is registered by hand (not through REGISTER_ACTIVATION_OP) because
// its gradient operator also lists ops::SquareDoubleGradMaker, which emits
// the square_grad_grad operator registered below.
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// Forward and first-order gradient CPU kernels (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

// Second-order gradient CPU kernels (float, double, float16).
REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ==========================   pow register  ============================ */

// pow uses its own operator classes (ops::PowOp / ops::PowOpGrad) and its own
// grad-op maker (ops::PowGradOpDescMaker) instead of the generic activation
// ones. The last argument is chosen at compile time: SingleOpInplaceInToOut
// when CanInplaceAct<PowGradFunctor<float>>() is true, otherwise `void`.
REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
    ops::PowGradOpDescMaker,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInference);

// CPU kernels for pow and pow_grad, each registered for float and double.
REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */