activation_op.cc 34.2 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Q
qijun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Q
qijun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
T
tink2123 已提交
16
#include <memory>
D
dzhwinter 已提交
17
#include <string>
18
#include <type_traits>
T
tink2123 已提交
19
#include <unordered_map>
20
#include <vector>
21
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
D
dzhwinter 已提交
22
#include "paddle/fluid/platform/port.h"
23 24 25
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
Q
qijun 已提交
26

A
Adam 已提交
27 28
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
29 30 31
namespace paddle {
namespace operators {

32 33
using paddle::framework::Tensor;

34 35 36 37 38
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

39 40 41 42 43
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
44 45 46 47 48
      AddInput("X", "Input of " #OP_NAME                                     \
                    " operator, an N-D Tensor, with data type float32, "     \
                    "float64 or float16.");                                  \
      AddOutput("Out", "Output of " #OP_NAME                                 \
                       " operator, a Tensor with shape same as input.");     \
49 50 51 52 53 54 55 56 57 58 59 60 61 62
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
D
dzhwinter 已提交
63
  }
D
dzhwinter 已提交
64

65 66 67 68 69 70 71 72 73 74 75 76 77
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

A
Adam 已提交
78 79 80 81
    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                             boost::get<bool>(op->GetAttr("use_mkldnn")))) {
82 83 84 85 86 87 88 89 90
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
D
dzhwinter 已提交
91
  }
92
};
D
dzhwinter 已提交
93

94 95 96 97
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
M
mozga-intel 已提交
98
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
99 100 101 102 103 104 105 106 107 108
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
109 110 111 112 113
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
M
mozga-intel 已提交
114
    layout = framework::DataLayout::kMKLDNN;
115 116 117
  }
#endif
  return framework::OpKernelType(
C
chengduo 已提交
118 119
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
120 121
}

Q
qijun 已提交
122 123 124 125
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

126
  void InferShape(framework::InferShapeContext* ctx) const override {
127
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
128
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
129
  }
130

131
 protected:
132 133 134 135
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
136 137
};

C
chengduo 已提交
138 139 140 141 142 143
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
144 145 146
  }
};

Q
qijun 已提交
147 148 149 150
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

151
  void InferShape(framework::InferShapeContext* ctx) const override {
152 153 154
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
155
  }
156

157
 protected:
158 159
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
160
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
161
  }
Q
qijun 已提交
162 163
};

D
dzhwinter 已提交
164
UNUSED constexpr char SigmoidDoc[] = R"DOC(
165
Sigmoid Activation Operator
K
Kexin Zhao 已提交
166

167
$$out = \\frac{1}{1 + e^{-x}}$$
K
Kexin Zhao 已提交
168

D
dzhwinter 已提交
169
)DOC";
Q
qijun 已提交
170

D
dzhwinter 已提交
171
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
172
Logsigmoid Activation Operator
K
Kexin Zhao 已提交
173

174
$$out = \\log \\frac{1}{1 + e^{-x}}$$
K
Kexin Zhao 已提交
175

D
dzhwinter 已提交
176
)DOC";
177

D
dzhwinter 已提交
178
UNUSED constexpr char ExpDoc[] = R"DOC(
179
Exp Operator. Computes exp of x element-wise with a natural number :math:`e` as the base.
K
Kexin Zhao 已提交
180

F
fengjiayi 已提交
181
$out = e^x$
K
Kexin Zhao 已提交
182

D
dzhwinter 已提交
183
)DOC";
Q
qijun 已提交
184

D
dzhwinter 已提交
185
UNUSED constexpr char ReluDoc[] = R"DOC(
K
kexinzhao 已提交
186
Relu Activation Operator.
K
Kexin Zhao 已提交
187

F
fengjiayi 已提交
188
$out = \max(x, 0)$
K
Kexin Zhao 已提交
189

D
dzhwinter 已提交
190
)DOC";
K
Kexin Zhao 已提交
191

C
Clementine 已提交
192 193 194 195 196 197 198
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

D
dzhwinter 已提交
199
UNUSED constexpr char TanhDoc[] = R"DOC(
K
kexinzhao 已提交
200
Tanh Activation Operator.
K
Kexin Zhao 已提交
201

Q
update  
qiaolongfei 已提交
202
$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
K
Kexin Zhao 已提交
203

D
dzhwinter 已提交
204
)DOC";
205

D
dzhwinter 已提交
206
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
K
kexinzhao 已提交
207
TanhShrink Activation Operator.
K
Kexin Zhao 已提交
208

Y
Yan Chunwei 已提交
209
$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
K
Kexin Zhao 已提交
210

D
dzhwinter 已提交
211
)DOC";
K
Kexin Zhao 已提交
212

D
dzhwinter 已提交
213
UNUSED constexpr char SqrtDoc[] = R"DOC(
K
kexinzhao 已提交
214
Sqrt Activation Operator.
K
Kexin Zhao 已提交
215

216
.. math:: out=\sqrt x=x^{1/2}
217

218 219
**Note**:
  input value must be greater than or equal to zero.
K
Kexin Zhao 已提交
220

D
dzhwinter 已提交
221
)DOC";
222

Z
zhoukunsheng 已提交
223 224 225 226 227 228 229 230 231
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

D
dzhwinter 已提交
232
UNUSED constexpr char AbsDoc[] = R"DOC(
K
kexinzhao 已提交
233
Abs Activation Operator.
K
Kexin Zhao 已提交
234

F
fengjiayi 已提交
235
$out = |x|$
K
Kexin Zhao 已提交
236

D
dzhwinter 已提交
237
)DOC";
238

D
dzhwinter 已提交
239
UNUSED constexpr char CeilDoc[] = R"DOC(
240
Ceil Operator. Computes ceil of x element-wise.
D
dzhwinter 已提交
241

242
$out = \left \lceil x \right \rceil$
D
dzhwinter 已提交
243

D
dzhwinter 已提交
244
)DOC";
D
dzhwinter 已提交
245

D
dzhwinter 已提交
246
UNUSED constexpr char FloorDoc[] = R"DOC(
D
dzhwinter 已提交
247 248
Floor Activation Operator.

249
$out = \left \lfloor x \right \rfloor$
D
dzhwinter 已提交
250

D
dzhwinter 已提交
251
)DOC";
D
dzhwinter 已提交
252

D
dzhwinter 已提交
253
UNUSED constexpr char CosDoc[] = R"DOC(
254
Cosine Operator. Computes cosine of x element-wise.
C
add cos  
chengduoZH 已提交
255 256 257

$out = cos(x)$

D
dzhwinter 已提交
258
)DOC";
C
add cos  
chengduoZH 已提交
259

D
dzhwinter 已提交
260
UNUSED constexpr char SinDoc[] = R"DOC(
C
add sin  
chengduoZH 已提交
261 262 263 264
Sine Activation Operator.

$out = sin(x)$

D
dzhwinter 已提交
265
)DOC";
C
add sin  
chengduoZH 已提交
266

D
dzhwinter 已提交
267
UNUSED constexpr char RoundDoc[] = R"DOC(
268
The OP rounds the values in the input to the nearest integer value.
D
dzhwinter 已提交
269

270 271 272 273 274 275 276 277 278
.. code-block:: python

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]
D
dzhwinter 已提交
279

D
dzhwinter 已提交
280
)DOC";
D
dzhwinter 已提交
281

D
dzhwinter 已提交
282
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
K
kexinzhao 已提交
283
Reciprocal Activation Operator.
K
Kexin Zhao 已提交
284

285
$$out = \\frac{1}{x}$$
K
Kexin Zhao 已提交
286

D
dzhwinter 已提交
287
)DOC";
288

D
dzhwinter 已提交
289
UNUSED constexpr char LogDoc[] = R"DOC(
K
kexinzhao 已提交
290
Log Activation Operator.
K
Kexin Zhao 已提交
291

F
fengjiayi 已提交
292
$out = \ln(x)$
K
Kexin Zhao 已提交
293 294 295

Natural logarithm of x.

D
dzhwinter 已提交
296 297
)DOC";

D
dzhwinter 已提交
298
UNUSED constexpr char SquareDoc[] = R"DOC(
299
The OP square each elements of the inputs.
D
dzhwinter 已提交
300 301

$out = x^2$
302

D
dzhwinter 已提交
303 304
)DOC";

D
dzhwinter 已提交
305
UNUSED constexpr char SoftplusDoc[] = R"DOC(
D
dzhwinter 已提交
306 307 308 309 310 311
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

D
dzhwinter 已提交
312
UNUSED constexpr char SoftsignDoc[] = R"DOC(
D
dzhwinter 已提交
313 314
Softsign Activation Operator.

315
$$out = \\frac{x}{1 + \|x\|}$$
D
dzhwinter 已提交
316 317 318

)DOC";

T
tink2123 已提交
319 320 321 322 323 324
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
325 326
Arccosine Activation Operator.

T
tink2123 已提交
327
$$out = \cos^{-1}(x)$$
328

T
tink2123 已提交
329 330 331
)DOC");
  }
};
332

T
tink2123 已提交
333 334 335 336 337 338
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
339 340
Arcsine Activation Operator.

T
tink2123 已提交
341
$$out = \sin^{-1}(x)$$
342

T
tink2123 已提交
343 344 345
)DOC");
  }
};
346

T
tink2123 已提交
347 348 349 350 351 352
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
353 354
Arctanh Activation Operator.

T
tink2123 已提交
355
$$out = \tanh^{-1}(x)$$
356

T
tink2123 已提交
357 358 359
)DOC");
  }
};
360

D
dzhwinter 已提交
361
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
362
 public:
Y
Yu Yang 已提交
363
  void Make() override {
W
Wilber 已提交
364 365 366 367 368 369 370 371
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
A
Adam 已提交
372 373 374 375 376 377 378
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
K
Kexin Zhao 已提交
379
    AddComment(R"DOC(
D
dzhwinter 已提交
380
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
381

W
Wilber 已提交
382
$$out = \max(x, \alpha * x)$$
K
Kexin Zhao 已提交
383 384

)DOC");
385 386 387
  }
};

D
dzhwinter 已提交
388
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
389
 public:
Y
Yu Yang 已提交
390
  void Make() override {
D
dzhwinter 已提交
391 392 393
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
394
    AddComment(R"DOC(
395 396 397
:strong:`Softshrink Activation Operator`

..  math::
398
    out = \begin{cases}
399 400 401 402
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
403 404

)DOC");
K
kexinzhao 已提交
405 406 407
  }
};

D
dzhwinter 已提交
408
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
409
 public:
Y
Yu Yang 已提交
410
  void Make() override {
D
dzhwinter 已提交
411 412
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
413 414
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
415
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
416
    AddComment(R"DOC(
Y
yuyang18 已提交
417
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
418

Y
yuyang18 已提交
419 420 421 422 423 424
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
425 426

)DOC");
427 428 429
  }
};

430 431
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
432
  void Make() override {
433 434 435 436 437 438
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``X``.");
439 440 441 442
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
K
Kexin Zhao 已提交
443
    AddComment(R"DOC(
K
kexinzhao 已提交
444
BRelu Activation Operator.
K
Kexin Zhao 已提交
445

446
$out = \min(\max(x, t_{min}), t_{max})$
K
Kexin Zhao 已提交
447 448

)DOC");
449 450 451 452 453
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
454
  void Make() override {
455
    AddInput("X", "Input of SoftRelu operator");
F
fengjiayi 已提交
456
    AddOutput("Out", "Output of SoftRelu operator");
457 458
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
K
Kexin Zhao 已提交
459
    AddComment(R"DOC(
K
kexinzhao 已提交
460
SoftRelu Activation Operator.
K
Kexin Zhao 已提交
461

T
tensor-tang 已提交
462
$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$
K
Kexin Zhao 已提交
463 464

)DOC");
465 466 467
  }
};

468 469
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
470
  void Make() override {
471 472 473 474 475 476
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32 or float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has same "
              "dimension and data type as the ``x``.");
477
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
478
    AddComment(R"DOC(
K
kexinzhao 已提交
479
ELU Activation Operator.
K
Kexin Zhao 已提交
480 481 482 483

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

F
fengjiayi 已提交
484
$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
K
Kexin Zhao 已提交
485 486

)DOC");
487 488 489
  }
};

490 491
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
492
  void Make() override {
Z
zhupengyang 已提交
493 494 495 496 497 498 499 500
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
501
        .SetDefault(6.0f);
K
Kexin Zhao 已提交
502
    AddComment(R"DOC(
K
kexinzhao 已提交
503
Relu6 Activation Operator.
K
Kexin Zhao 已提交
504

Z
zhupengyang 已提交
505
$out = \min(\max(0, x), threshold)$
K
Kexin Zhao 已提交
506 507

)DOC");
508 509 510
  }
};

511 512
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
513
  void Make() override {
514
    AddInput("X", "Input of Pow operator");
515 516 517 518 519
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this"
             "The shape of FactorTensor MUST BE [1]."
             "it has higher priority than attr(factor).")
        .AsDispensable();
F
fengjiayi 已提交
520
    AddOutput("Out", "Output of Pow operator");
521
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
K
Kexin Zhao 已提交
522
    AddComment(R"DOC(
K
kexinzhao 已提交
523
Pow Activation Operator.
K
Kexin Zhao 已提交
524

F
fengjiayi 已提交
525
$out = x^{factor}$
K
Kexin Zhao 已提交
526 527

)DOC");
528 529 530 531 532
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
533
  void Make() override {
534 535 536 537 538 539
    AddInput("X",
             "Input of STanh operator."
             " A LoDTensor or Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
540 541
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
K
Kexin Zhao 已提交
542
    AddComment(R"DOC(
K
kexinzhao 已提交
543
STanh Activation Operator.
K
Kexin Zhao 已提交
544

Y
Yan Chunwei 已提交
545
$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
K
Kexin Zhao 已提交
546 547

)DOC");
Q
qijun 已提交
548 549 550
  }
};

551 552
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
553
  void Make() override {
554
    AddInput("X", "Input of ThresholdedRelu operator");
F
fengjiayi 已提交
555
    AddOutput("Out", "Output of ThresholdedRelu operator");
Y
yuyang18 已提交
556 557
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
558
        .SetDefault(1.0f);
K
Kexin Zhao 已提交
559
    AddComment(R"DOC(
Y
yuyang18 已提交
560
:strong:`ThresholdedRelu activation operator`
K
Kexin Zhao 已提交
561

Y
yuyang18 已提交
562
..  math::
K
Kexin Zhao 已提交
563

Y
yuyang18 已提交
564
    out = \begin{cases}
Y
yuyang18 已提交
565
             x,  \text{if } x > threshold \\
Y
yuyang18 已提交
566 567
             0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
568
)DOC");
569 570 571
  }
};

572 573
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
574
  void Make() override {
575 576 577 578 579
    AddInput("X", "An N-D Tensor with data type float32, float64. ");
    AddOutput("Out", "A Tensor with the same shape as input. ");
    AddAttr<float>("slope",
                   "The slope of the linear approximation of sigmoid. Its "
                   "value MUST BE positive. Default is 0.2. ")
580
        .SetDefault(0.2f);
581 582 583
    AddAttr<float>(
        "offset",
        "The offset of the linear approximation of sigmoid. Default is 0.5. ")
584
        .SetDefault(0.5f);
585
    AddComment(R"DOC(
K
kexinzhao 已提交
586
HardSigmoid Activation Operator.
587

588
A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
K
Kexin Zhao 已提交
589
which is much faster than sigmoid.
590

591
$out = \max(0, \min(1, slope * x + offset))$
592

K
Kexin Zhao 已提交
593
)DOC");
594 595 596
  }
};

A
Abhinav Arora 已提交
597 598
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
599
  void Make() override {
A
Abhinav Arora 已提交
600
    AddInput("X", "Input of Swish operator");
F
fengjiayi 已提交
601
    AddOutput("Out", "Output of Swish operator");
A
Abhinav Arora 已提交
602 603 604 605
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

606
$$out = \\frac{x}{1 + e^{- \beta \ x}}$$
A
Abhinav Arora 已提交
607 608 609 610 611

)DOC");
  }
};

H
huangjun12 已提交
612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

D
dzhwinter 已提交
638 639 640 641
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
C
Clementine 已提交
642
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
D
dzhwinter 已提交
643 644 645
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
Z
zhoukunsheng 已提交
646
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
D
dzhwinter 已提交
647 648 649 650 651 652 653 654 655 656 657 658
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

659
template <ActBwdOpFwdDeps kDepValue>
660 661 662 663 664
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
665
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
666
      if (ctx->HasOutput("DX")) {
667 668 669
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
670
      if (ctx->HasOutput("DDOut")) {
671 672 673
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
674
    }
675
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
676
      if (ctx->HasOutput("DOut")) {
677 678 679
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
708 709 710
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
711 712 713 714 715 716 717 718 719 720
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

721 722 723 724 725 726 727 728 729 730 731 732 733 734
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
Q
qingqing01 已提交
735
    // input2: ddx
736 737
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
738
    // output: ddy
739 740 741 742 743
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

744 745 746 747 748 749 750 751 752 753 754
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("leaky_relu_grad_grad");
Z
Zeng Jinle 已提交
755 756
    // input1: Out
    op->SetInput("Out", Input("Out"));
757 758 759 760 761 762 763 764 765
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

L
lvmengsi 已提交
766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

813 814 815
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
816 817
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
                           {"DDX", "DDOut"});
818

819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("pow_grad");
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetInput("FactorTensor", Input("FactorTensor"));
    op->SetAttrMap(Attrs());

    return op;
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
Q
qijun 已提交
888 889 890 891
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
892
namespace plat = paddle::platform;
893

894 895 896 897 898 899 900 901
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
902 903
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);
904 905 906

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
Q
QI JUN 已提交
907 908 909 910 911 912 913 914 915 916
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
Y
Yu Yang 已提交
917
                                ops::grad_functor<double>>);
918

919 920
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
921

922
/* ==========================    relu register  ============================= */
923 924 925 926 927
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
928
                  ops::ActivationGradOpInplaceInference,
929
                  ops::ReluDoubleGradMaker);
930 931
REGISTER_OPERATOR(
    relu_grad_grad,
932 933
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);
934 935 936 937 938 939 940 941 942 943 944

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
945
/* ========================================================================== */
946

947
/* ======================== leaky relu register  ============================ */
948 949 950 951 952 953
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
954
                  ops::ActivationGradOpInplaceInference,
955
                  ops::LeakyReluDoubleGradMaker);
956 957
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
958 959
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);
960

961 962 963 964 965 966 967 968 969 970
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
971 972
/* ========================================================================== */

L
lvmengsi 已提交
973 974 975 976 977 978
/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
979
                  ops::ActivationGradOpInplaceInference,
L
lvmengsi 已提交
980 981 982
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
983 984 985
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

L
lvmengsi 已提交
986 987 988 989 990 991 992 993 994 995
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

996 997 998 999 1000 1001 1002
/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
1003
                  ops::ActivationGradOpInplaceInference,
1004 1005 1006
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
1007 1008
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);
1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040

/* ==========================   pow register  ============================ */

REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
    ops::PowGradOpDescMaker,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInference);

REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */