/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

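// An activation can be computed in place (Out reusing X's buffer) only when
// its backward pass never reads the forward input X, i.e. the gradient
// depends solely on Out (kDepOut) or on nothing at all (kNoDeps).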
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

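// Generates an OpProtoAndCheckerMaker for a simple element-wise activation:
// one input X, one output Out of the same shape, plus the shared use_mkldnn /
// use_cudnn / is_test attributes. OP_COMMENT supplies the operator docstring.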
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME                                     \
                    " operator, an N-D Tensor, with data type float32, "     \
                    "float64 or float16.");                                  \
      AddOutput("Out", "Output of " #OP_NAME                                 \
                       " operator, a Tensor with shape same as input.");     \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }

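// Builds the <op>_grad OpDesc for an activation. The forward input X is wired
// into the grad op only when the functor declares kDepX (or MKL-DNN is in
// use), and the forward output Out only when it declares kDepOut, so tensors
// that the backward pass does not need are not kept alive.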
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                             boost::get<bool>(op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};

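// Selects the kernel library (plain, cuDNN or MKL-DNN) and data layout for an
// activation op, based on the build options, the op's attributes and the data
// type of the variable named `name`.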
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

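// Out inherits both the data type and the variable type of the input X.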
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator.

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
LogSigmoid Activation Operator.

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

.. math:: out = \sqrt{x} = x^{1/2}

**Note**:
  The input value must be greater than or equal to zero.

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure the input is positive to avoid numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = \cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = \sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.

.. code-block:: python

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};

class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};

class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "A LoDTensor or Tensor representing preactivation values. Must be "
             "one of the following types: float32, float64.");
    AddOutput(
        "Out",
        "A LoDTensor or Tensor with the same type and size as that of x.");
    AddAttr<float>("alpha", "Slope of the activation function at x < 0.")
        .SetDefault(0.02f);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$$out = \max(x, \alpha * x)$$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this as "
             "the exponent, which has higher priority than attr(factor). "
             "The shape of FactorTensor MUST BE [1].")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A LoDTensor or Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

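// Shape/LoD inference for the *_grad_grad ops: DX, DOut and DDOut mirror X or
// Out depending on which forward tensors the activation's gradient functor
// declares as dependencies.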
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

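// Same as ActivationOpDoubleGrad, but only produces DDOut; used by
// activations such as relu and leaky_relu below whose double-grad kernel has
// a single output.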
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("leaky_relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

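// The grad op may reuse the buffer of Out@GRAD for X@GRAD, and the
// double-grad op may reuse DDX for DDOut; each pair has the same shape, which
// is what the inplace pass relies on.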
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
                           {"DDX", "DDOut"});

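// pow cannot reuse the generic ActivationOp / ActivationGradOpDescMaker
// because it accepts an optional FactorTensor input in addition to the
// "factor" attribute, so it gets its own op, grad maker and kernel-type rules.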
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("pow_grad");
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetInput("FactorTensor", Input("FactorTensor"));
    op->SetAttrMap(Attrs());

    return op;
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

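// Registers the forward and backward operators of one activation. The forward
// op is made in-place capable only when CanInplaceAct() holds for its gradient
// functor; the backward op always allows Out@GRAD to be reused for X@GRAD.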
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
Y
905

906 907
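// Most activations are registered in bulk here; relu, leaky_relu, sqrt,
// square and pow are registered by hand below because they also provide
// double-grad kernels (or, for pow, the FactorTensor input).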
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
908

909
/* ==========================    relu register  ============================= */
910 911 912 913 914
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
915
                  ops::ActivationGradOpInplaceInference,
916
                  ops::ReluDoubleGradMaker);
917 918
REGISTER_OPERATOR(
    relu_grad_grad,
919 920
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);
921 922 923 924 925 926 927 928 929 930 931

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
932
/* ========================================================================== */
933

934
/* ======================== leaky relu register  ============================ */
935 936 937 938 939 940
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
941
                  ops::ActivationGradOpInplaceInference,
942
                  ops::LeakyReluDoubleGradMaker);
943 944
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
945 946
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);
947

948 949 950 951 952 953 954 955 956 957
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
958 959
/* ========================================================================== */

L
lvmengsi 已提交
960 961 962 963 964 965
/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
966
                  ops::ActivationGradOpInplaceInference,
L
lvmengsi 已提交
967 968 969
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
970 971 972
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

L
lvmengsi 已提交
973 974 975 976 977 978 979 980 981 982
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

983 984 985 986 987 988 989
/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
990
                  ops::ActivationGradOpInplaceInference,
991 992 993
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
994 995
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);
996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027

/* ==========================   pow register  ============================ */

REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
    ops::PowGradOpDescMaker,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInference);

REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */