/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}
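// For example, relu's backward pass reads only Out (see ReluGrad below), so
// CanInplaceAct<ReluGradFunctor<float>>() is expected to be true and X and
// Out may share a buffer; functors whose gradient needs X cannot be inplaced.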

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME                                     \
                    " operator, an N-D Tensor, with data type float32, "     \
                    "float64 or float16.");                                  \
      AddOutput("Out", "Output of " #OP_NAME                                 \
                       " operator, a Tensor with shape same as input.");     \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }
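// For instance, REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc) expands to
// a SigmoidOpMaker whose Make() declares the common X/Out variables, the
// use_mkldnn, use_cudnn and is_test attributes, and attaches SigmoidDoc as the
// operator comment; the maker classes are consumed by the registration macros
// at the bottom of this file.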

template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
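    // X is forwarded to the grad op only when the functor depends on it
    // (kDepX) or when MKL-DNN is in use; Out is forwarded only for kDepOut
    // functors (see the two conditions below).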

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                             boost::get<bool>(op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};
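// Because the maker is parameterized on FwdDeps(), each generated <op>_grad
// descriptor captures only the forward tensors its functor actually needs,
// e.g. an Out-dependent activation keeps just Out and Out@GRAD, so the
// forward X need not be retained for the backward pass.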

framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}
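// Kernel selection falls back in this order: the cuDNN branch is currently
// disabled (see the FIXME above and issue #16096), MKL-DNN is picked when the
// attribute requests it and the context allows it, and otherwise a plain
// kernel using the data type of the named input variable is returned.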

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

.. math:: out=\sqrt x=x^{1/2}

**Note**:
  input value must be greater than or equal to zero.

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.

.. code-block:: python

  input:
    x.shape = [4]
    x.data = [1.2, -0.9, 3.4, 0.9]

  output:
    out.shape = [4]
    out.data = [1., -1., 3., 1.]

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};

class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};

class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this "
             "as the exponential factor. The shape of FactorTensor MUST "
             "BE [1]. It has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A LoDTensor or Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};
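// ActivationOpDoubleGrad declares DX/DOut alongside DDOut and backs the sqrt
// and square double-grad ops registered below, whose makers also produce a
// gradient for the forward tensor; ActivationOpDoubleGrad2 declares only
// DDOut and is used by relu and leaky_relu.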

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};
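// In these double-grad makers, "DDX" is the gradient w.r.t. the first-order
// X@GRAD output and "DDOut" is the second-order gradient of the op's output;
// each maker forwards exactly the tensors its formula needs (Out here, X for
// square, Out plus DX for sqrt).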

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("leaky_relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
                           {"DDX", "DDOut"});
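// These inplace hints allow the executor to reuse Out@GRAD's buffer for
// X@GRAD in the first-order grad ops, and DDX's buffer for DDOut in the
// double-grad ops, whenever it judges the reuse to be safe.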

class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("pow_grad");
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetInput("FactorTensor", Input("FactorTensor"));
    op->SetAttrMap(Attrs());

    return op;
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
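// Both PowOp and PowOpGrad override GetKernelTypeForVar() so that
// "FactorTensor" keeps the expected kernel type unchanged, i.e. no layout or
// place transform is forced on the scalar exponent tensor.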

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);
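// For example, REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor,
// ReluGradFunctor) below instantiates float and double CPU kernels for the
// relu and relu_grad operators.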

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
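// FOR_EACH_ACTIVATION_OP (see activation_op.h) stamps out the two
// registrations above for every simple activation; relu, leaky_relu, sqrt,
// square and pow are registered by hand below because they additionally need
// double-grad ops or the FactorTensor input.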

/* ==========================    relu register  ============================= */
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== leaky relu register  ============================ */
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   pow register  ============================ */

REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
    ops::PowGradOpDescMaker,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInference);

REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */