/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
T
tink2123 已提交
16
#include <memory>
D
dzhwinter 已提交
17
#include <string>
18
#include <type_traits>
T
tink2123 已提交
19
#include <unordered_map>
20
#include <vector>
21
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
D
dzhwinter 已提交
22
#include "paddle/fluid/platform/port.h"
23 24 25
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
Q
qijun 已提交
26

A
Adam 已提交
27 28
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
29 30 31
namespace paddle {
namespace operators {

32 33
using paddle::framework::Tensor;

34 35 36 37 38
// Compile-time predicate over an activation's backward functor: true exactly
// when GradFunctor::FwdDeps() equals kNoDeps or kDepOut.
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kNoDeps ||
         GradFunctor::FwdDeps() == kDepOut;
}

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
// Defines class <OP_NAME>OpMaker, an OpProtoAndCheckerMaker whose Make()
// declares input "X", output "Out", the boolean attributes "use_mkldnn",
// "use_cudnn" and "is_test" (all defaulting to false), and uses OP_COMMENT as
// the operator comment. The expansion has no trailing ';' -- the caller
// supplies it.
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }
D
dzhwinter 已提交
61

62 63 64 65 66 67 68 69 70 71 72 73 74
// Builds the description of the "<forward_type>_grad" operator for an
// activation. kDepValue is a bit mask saying which forward tensors (X and/or
// Out) the backward computation reads; only those are wired in as inputs.
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  // Returns the grad op desc: Out@GRAD in, X@GRAD out, attributes copied from
  // the forward op, plus X and/or Out when required (see below).
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

    // X is passed when the functor depends on it, and also whenever MKL-DNN
    // is enabled through the global flag or the op's "use_mkldnn" attribute.
    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                             boost::get<bool>(op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", Input("X"));
    }

    // Out is passed only when the functor depends on it.
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};
D
dzhwinter 已提交
90

91 92 93 94
// Chooses the kernel type for an activation operator.
//
// ctx:  execution context supplying the place and the input variable `name`.
// oper: the operator whose attribute map is inspected for "use_mkldnn".
// name: input variable whose data type determines the kernel data type.
//
// The library/layout default to plain/any-layout. When built with MKL-DNN,
// they switch to MKL-DNN if the operator carries a "use_mkldnn" attribute and
// platform::CanMKLDNNBeUsed(ctx) holds.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  // The kPlain check keeps MKL-DNN from overriding a library chosen earlier
  // (relevant once the cuDNN branch above is re-enabled).
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

Q
qijun 已提交
119 120 121 122
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

123
  void InferShape(framework::InferShapeContext* ctx) const override {
124
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
125
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
126
  }
127

128
 protected:
129 130 131 132
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
133 134
};

C
chengduo 已提交
135 136 137 138 139 140
// Variable-type inference for activations: declares that "Out" has the same
// dtype and variable type as "X".
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  // Returns the input -> output name mapping used by the base class.
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

Q
qijun 已提交
144 145 146 147
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

148
  void InferShape(framework::InferShapeContext* ctx) const override {
149 150 151
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
152
  }
153

154
 protected:
155 156
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
157
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
158
  }
Q
qijun 已提交
159 160
};

D
dzhwinter 已提交
161
// Operator comment text for the Sigmoid activation.
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";
Q
qijun 已提交
167

D
dzhwinter 已提交
168
// Operator comment text for the LogSigmoid activation.
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";
174

D
dzhwinter 已提交
175
// Operator comment text for the Exp activation.
UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";
Q
qijun 已提交
181

D
dzhwinter 已提交
182
// Operator comment text for the Relu activation.
UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";
K
Kexin Zhao 已提交
188

C
Clementine 已提交
189 190 191 192 193 194 195
// Operator comment text for the Gelu activation.
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

D
dzhwinter 已提交
196
// Operator comment text for the Tanh activation.
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";
202

D
dzhwinter 已提交
203
// Operator comment text for the TanhShrink activation.
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";
K
Kexin Zhao 已提交
209

D
dzhwinter 已提交
210
// Operator comment text for the Sqrt activation.
UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

Please make sure legal input, when input a negative value closed to zero,
you should add a small epsilon(1e-12) to avoid negative number caused by numerical errors.

$out = \sqrt{x}$

)DOC";
219

Z
zhoukunsheng 已提交
220 221 222 223 224 225 226 227 228
// Operator comment text for the Rsqrt (reciprocal square root) activation.
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

D
dzhwinter 已提交
229
// Operator comment text for the Abs activation.
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";
235

D
dzhwinter 已提交
236
// Operator comment text for the Ceil activation.
UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";
D
dzhwinter 已提交
242

D
dzhwinter 已提交
243
// Operator comment text for the Floor activation.
UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";
D
dzhwinter 已提交
249

D
dzhwinter 已提交
250
// Operator comment text for the Cos activation.
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";
C
add cos  
chengduoZH 已提交
256

D
dzhwinter 已提交
257
// Operator comment text for the Sin activation.
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";
C
add sin  
chengduoZH 已提交
263

D
dzhwinter 已提交
264
// Operator comment text for the Round activation.
UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";
D
dzhwinter 已提交
270

D
dzhwinter 已提交
271
// Operator comment text for the Reciprocal activation.
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";
277

D
dzhwinter 已提交
278
// Operator comment text for the Log (natural logarithm) activation.
UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

D
dzhwinter 已提交
287
// Operator comment text for the Square activation.
UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

D
dzhwinter 已提交
294
// Operator comment text for the Softplus activation.
UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

D
dzhwinter 已提交
301
// Operator comment text for the Softsign activation.
UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

T
tink2123 已提交
308 309 310 311 312 313
// Proto maker for the acos operator: one input "X", one output "Out".
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};
321

T
tink2123 已提交
322 323 324 325 326 327
// Proto maker for the asin operator: one input "X", one output "Out".
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};
335

T
tink2123 已提交
336 337 338 339 340 341
// Proto maker for the atan operator: one input "X", one output "Out".
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    // The operator is atan (inverse tangent); the comment previously
    // described the inverse hyperbolic tangent by mistake.
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};
349

D
dzhwinter 已提交
350
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
351
 public:
Y
Yu Yang 已提交
352
  void Make() override {
D
dzhwinter 已提交
353 354 355
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
A
Adam 已提交
356 357 358 359 360 361 362
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
K
Kexin Zhao 已提交
363
    AddComment(R"DOC(
D
dzhwinter 已提交
364
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
365

D
dzhwinter 已提交
366
$out = \max(x, \alpha * x)$
K
Kexin Zhao 已提交
367 368

)DOC");
369 370 371
  }
};

D
dzhwinter 已提交
372
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
373
 public:
Y
Yu Yang 已提交
374
  void Make() override {
D
dzhwinter 已提交
375 376 377
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
378
    AddComment(R"DOC(
379 380 381
:strong:`Softshrink Activation Operator`

..  math::
382
    out = \begin{cases}
383 384 385 386
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
387 388

)DOC");
K
kexinzhao 已提交
389 390 391
  }
};

D
dzhwinter 已提交
392
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
393
 public:
Y
Yu Yang 已提交
394
  void Make() override {
D
dzhwinter 已提交
395 396
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
397 398
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
399
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
400
    AddComment(R"DOC(
Y
yuyang18 已提交
401
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
402

Y
yuyang18 已提交
403 404 405 406 407 408
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
409 410

)DOC");
411 412 413
  }
};

414 415
// Proto maker for brelu (bounded relu): input "X", output "Out" and float
// attributes "t_min" (default 0) and "t_max" (default 24).
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    // Clamping x into [t_min, t_max] is min(max(x, t_min), t_max); the
    // previous text had min/max swapped, which would always yield t_max.
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

// Proto maker for soft_relu: input "X", output "Out" and float attribute
// "threshold" (default 40).
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

448 449
// Proto maker for elu: input "X", output "Out" and float attribute "alpha"
// (default 1).
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

466 467
// Proto maker for relu6: input "X", output "Out" and float attribute
// "threshold" (default 6).
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    // The upper bound is the configurable "threshold" attribute; 6 is only
    // its default, so the formula names the attribute.
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), threshold)$

)DOC");
  }
};

482 483
// Proto maker for pow: input "X", optional input "FactorTensor", output "Out"
// and float attribute "factor" (default 1).
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    // Adjacent string literals are concatenated verbatim, so each piece ends
    // with explicit punctuation and a space to keep the sentences apart.
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this. "
             "The shape of FactorTensor MUST BE [1]. "
             "it has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

// Proto maker for stanh (scaled tanh): input "X", output "Out" and float
// attributes "scale_a" (default 2/3) and "scale_b" (default 1.7159).
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

520 521
// Proto maker for thresholded_relu: input "X", output "Out" and float
// attribute "threshold" (default 1).
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

541 542
// Proto maker for hard_sigmoid: input "X", output "Out" and float attributes
// "slope" (default 0.2) and "offset" (default 0.5).
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    // The text uses "offset" throughout to match the attribute name (it
    // previously called the same quantity "shift").
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

A
Abhinav Arora 已提交
566 567
// Proto maker for swish: input "X", output "Out" and float attribute "beta"
// (default 1).
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

H
huangjun12 已提交
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
// Proto maker for hard_swish: input "X", output "Out" and float attributes
// "threshold" (default 6), "scale" (default 6) and "offset" (default 3).
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the operator's inputs, outputs, attributes and comment text.
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

D
dzhwinter 已提交
607 608 609 610
// Generate the <Name>OpMaker classes for the activations whose proto is just
// X -> Out plus the common attributes, pairing each with its comment text.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

628
template <ActBwdOpFwdDeps kDepValue>
629 630 631 632 633
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
634
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
635
      if (ctx->HasOutput("DX")) {
636 637 638
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
639
      if (ctx->HasOutput("DDOut")) {
640 641 642
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
643
    }
644
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
645
      if (ctx->HasOutput("DOut")) {
646 647 648
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Second-order gradient operator variant whose shape inference only handles
// the "DDOut" output. kDepValue is a bit mask selecting whether "X" and/or
// "Out" supplies its shape.
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // "DDOut" is optional; dims/LoD are shared only when it exists.
  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  // Kernel type is derived from the data type of input "DDX".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

690 691 692 693 694 695 696 697 698 699 700 701 702 703
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
// Builds the "relu_grad_grad" op desc from the relu_grad op.
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by unique_ptr from the start so the desc is released if any of
    // the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

713 714 715 716 717 718 719 720 721 722 723
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
//
// Builds the "leaky_relu_grad_grad" op desc from the leaky_relu_grad op.
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by unique_ptr from the start so the desc is released if any of
    // the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("leaky_relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

L
lvmengsi 已提交
735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
//
// Builds the "sqrt_grad_grad" op desc from the sqrt_grad op.
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by unique_ptr from the start so the desc is released if any of
    // the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("sqrt_grad_grad");
    // Inputs: forward output, the first-order X@GRAD, and its gradient.
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Outputs: gradient w.r.t. Out and w.r.t. Out@GRAD.
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
//
// Builds the "square_grad_grad" op desc from the square_grad op.
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by unique_ptr from the start so the desc is released if any of
    // the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

782 783 784
// Declares the in-place inferer used by the activation grad ops, pairing
// Out@GRAD with X@GRAD.
// NOTE(review): presumably this lets X@GRAD reuse the Out@GRAD buffer --
// confirm against the DECLARE_INPLACE_OP_INFERER definition.
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
785 786
// Declares the in-place inferer for the double-grad ops, pairing "DDX" with
// "DDOut".
// NOTE(review): presumably this lets DDOut reuse the DDX buffer -- confirm
// against the DECLARE_INPLACE_OP_INFERER definition.
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
                           {"DDX", "DDOut"});
787

788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856
// Produces the "pow_grad" op desc for a pow op. The grad op reads X,
// Out@GRAD and FactorTensor, writes X@GRAD, and inherits the forward op's
// attributes.
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> grad_desc(new framework::OpDesc());
    grad_desc->SetType("pow_grad");

    // Inputs.
    grad_desc->SetInput("X", Input("X"));
    grad_desc->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));

    // Output.
    grad_desc->SetOutput(framework::GradVarName("X"), InputGrad("X"));

    grad_desc->SetInput("FactorTensor", Input("FactorTensor"));
    grad_desc->SetAttrMap(Attrs());
    return grad_desc;
  }
};
// Forward pow operator.
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // "Out" mirrors the dimensions and LoD of "X".
  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", "Out");
    ctx->ShareLoD("X", "Out");
  }

 protected:
  // Kernel type follows the data type of input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  // "FactorTensor" is reported with the expected kernel type unchanged; any
  // other variable gets the expected data type combined with the tensor's own
  // place and layout.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name != "FactorTensor") {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
    return expected_kernel_type;
  }
};

// Operator class used for the "pow_grad" op (see the pow registration section
// at the end of this file).
class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // The gradient of "X" shares both its dims and its LoD with the gradient of
  // "Out".
  void InferShape(framework::InferShapeContext* ctx) const override {
    const auto dout_name = framework::GradVarName("Out");
    const auto dx_name = framework::GradVarName("X");
    ctx->ShareDim(dout_name, dx_name);
    ctx->ShareLoD(dout_name, dx_name);
  }

 protected:
  // The kernel type is derived from the gradient of "Out" via GetKernelType.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  // Every variable except "FactorTensor" gets a kernel type built from the
  // expected data type plus the tensor's own place and layout; for
  // "FactorTensor" the expected kernel type is returned unchanged.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    const bool is_factor_tensor = (var_name == "FactorTensor");
    if (!is_factor_tensor) {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
    return expected_kernel_type;
  }
};
Q
qijun 已提交
857 858 859 860
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
861
namespace plat = paddle::platform;
862

863 864 865 866 867 868 869 870
// Registers one activation operator together with its gradient operator.
//
//   KERNEL_TYPE  - operator type name; the gradient operator is registered as
//                  KERNEL_TYPE##_grad.
//   OP_NAME      - prefix of the proto maker class ops::OP_NAME##OpMaker.
//   functor      - not referenced by this macro; it is accepted so the macro
//                  has the same parameter list as
//                  REGISTER_ACTIVATION_CPU_KERNEL (both are passed to
//                  FOR_EACH_ACTIVATION_OP).
//   grad_functor - gradient functor template. Its FwdDeps() value selects the
//                  ActivationGradOpDescMaker instantiation, and
//                  CanInplaceAct<grad_functor<float>>() decides whether
//                  SingleOpInplaceInToOut is attached (otherwise `void` is
//                  passed in that position).
//
// Every line of the definition except the last must end with a backslash;
// nothing else may appear between the continuation lines.
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);
873 874 875

// Registers the CPU kernels of one activation operator and of its gradient
// operator, each for float and double.
//
//   act_type     - operator type name; the gradient kernels are registered
//                  under act_type##_grad.
//   op_name      - not referenced by this macro; it is accepted so the macro
//                  has the same parameter list as REGISTER_ACTIVATION_OP
//                  (both are passed to FOR_EACH_ACTIVATION_OP).
//   functor      - forward functor template, wrapped in ops::ActivationKernel.
//   grad_functor - gradient functor template, wrapped in
//                  ops::ActivationGradKernel.
//
// Every line of the definition except the last must end with a backslash;
// nothing else may appear between the continuation lines.
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);
887

888 889
// Apply both registration macros through FOR_EACH_ACTIVATION_OP: first the
// operator (and gradient operator) registration, then the CPU kernel
// registration.
// NOTE(review): the list of activations FOR_EACH_ACTIVATION_OP iterates over
// is defined outside this part of the file -- presumably it excludes relu,
// leaky_relu, sqrt, square and pow, which are registered by hand below;
// confirm against its definition.
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
890

891
/* ==========================    relu register  ============================= */
// relu is registered by hand instead of through REGISTER_ACTIVATION_OP because
// it also has a second-order gradient operator: relu_grad is given
// ReluDoubleGradMaker and relu_grad_grad is registered as its own operator.
// SingleOpInplaceInToOut is attached to the forward operator unconditionally.
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// CPU kernels for relu / relu_grad (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

// CPU kernels for relu_grad_grad (float, double, float16).
REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
915

916
/* ======================== leaky relu register  ============================ */
// leaky_relu is registered by hand instead of through REGISTER_ACTIVATION_OP
// because it also has a second-order gradient operator: leaky_relu_grad is
// given LeakyReluDoubleGradMaker and leaky_relu_grad_grad is registered as its
// own operator. SingleOpInplaceInToOut is attached to the forward operator
// unconditionally.
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// CPU kernels for leaky_relu / leaky_relu_grad (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
// CPU kernels for leaky_relu_grad_grad (float, double, float16).
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

L
lvmengsi 已提交
942 943 944 945 946 947
/* ===========================   sqrt register  ============================= */
// sqrt is registered by hand instead of through REGISTER_ACTIVATION_OP because
// it also has a second-order gradient operator: sqrt_grad is given
// SqrtDoubleGradMaker and sqrt_grad_grad is registered as its own operator.
// Unlike relu / leaky_relu, the grad_grad operator uses ActivationOpDoubleGrad
// parameterized by the FwdDeps() of the GradGrad functor.
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// CPU kernels for sqrt / sqrt_grad (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
// CPU kernels for sqrt_grad_grad (float, double, float16), using the
// dedicated SqrtDoubleGradKernel.
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

965 966 967 968 969 970 971
/* ==========================   square register  ============================ */
// square is registered by hand instead of through REGISTER_ACTIVATION_OP
// because it also has a second-order gradient operator: square_grad is given
// SquareDoubleGradMaker and square_grad_grad is registered as its own
// operator, using ActivationOpDoubleGrad parameterized by the FwdDeps() of the
// GradGrad functor.
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

// CPU kernels for square / square_grad (float, double).
REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

// CPU kernels for square_grad_grad (float, double, float16), using the
// dedicated SquareDoubleGradKernel.
REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009

/* ==========================   pow register  ============================ */

// pow uses its own operator classes (PowOp / PowOpGrad) and its own grad-op
// description maker (PowGradOpDescMaker) instead of the generic ActivationOp
// ones. Whether SingleOpInplaceInToOut is attached to the forward operator is
// decided at compile time by CanInplaceAct<PowGradFunctor<float>>(); `void` is
// passed in that position otherwise.
REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
    ops::PowGradOpDescMaker,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInference);

// CPU kernels for pow (float, double), using the dedicated PowKernel.
REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
// CPU kernels for pow_grad (float, double), using the dedicated PowGradKernel.
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */