/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

// True when the backward functor's FwdDeps() is kDepOut or kNoDeps, i.e. the
// gradient op is never wired to the forward input X by its dependency mask.
// REGISTER_ACTIVATION_OP uses this to decide whether the forward op may be
// declared in-place (X -> Out).
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

// Defines class `OP_NAME##OpMaker` for the simple unary activations: one input
// "X", one output "Out", the boolean attributes "use_mkldnn", "use_cudnn" and
// "is_test" (all defaulting to false), and OP_COMMENT as the description.
// The invocation must be followed by `;`, which terminates the class.
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }

// Builds the "<forward type>_grad" op description for an activation.
// The grad op always consumes Out@GRAD and produces X@GRAD, and inherits the
// forward op's attributes. X and/or Out are added as inputs only when the
// kDepValue mask requests them. X is also added whenever MKL-DNN may be used,
// either globally (FLAGS_use_mkldnn) or via a true "use_mkldnn" attribute.
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

    // NOTE(review): X is forwarded for MKL-DNN even when the functor does not
    // depend on X, presumably because the MKL-DNN grad kernels read it --
    // confirm against mkldnn_activation_op.
    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                             boost::get<bool>(op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};

// Selects the kernel type for an activation op or one of its gradients.
//   ctx  - execution context; supplies the place and the input variable.
//   oper - the operator whose attributes are inspected.
//   name - input variable whose data type determines the kernel data type.
// Returns a plain kernel with kAnyLayout, unless the build has
// PADDLE_WITH_MKLDNN, the op declares a "use_mkldnn" attribute, and
// platform::CanMKLDNNBeUsed(ctx) holds. In that case the MKL-DNN library and
// layout are requested.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  // With the cuDNN branch above disabled, `library` is always kPlain here.
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

123
  void InferShape(framework::InferShapeContext* ctx) const override {
124
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
125
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
126
  }
127

128
 protected:
129 130 131 132
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
133 134
};

C
chengduo 已提交
135 136 137 138 139 140
// Var-type inference for the activations. It maps input "X" to output "Out" so
// that the PassInDtypeAndVarTypeToOutput base copies X's dtype/var type to Out.
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

148
  void InferShape(framework::InferShapeContext* ctx) const override {
149 150 151
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
152
  }
153

154
 protected:
155 156
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
157
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
158
  }
Q
qijun 已提交
159 160
};

D
dzhwinter 已提交
161
// Description passed to REGISTER_ACTIVATION_OP_MAKER(Sigmoid, ...).
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, ...).
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Exp, ...).
UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Relu, ...).
UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Gelu, ...).
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Tanh, ...).
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(TanhShrink, ...).
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Sqrt, ...).
UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

Please make sure the input is legal. When the input is a negative value close to zero,
you should add a small epsilon (1e-12) to avoid negative numbers caused by numerical errors.

$out = \sqrt{x}$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Rsqrt, ...).
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Abs, ...).
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Ceil, ...).
UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Floor, ...).
UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Cos, ...).
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Sin, ...).
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Round, ...).
UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ...).
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Log, ...).
UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Square, ...).
UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Softplus, ...).
UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

// Description passed to REGISTER_ACTIVATION_OP_MAKER(Softsign, ...).
UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

// Proto maker for acos: one input "X", one output "Out", no attributes.
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};

// Proto maker for asin: one input "X", one output "Out", no attributes.
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};

// Proto maker for atan: one input "X", one output "Out", no attributes.
// The description previously called this the (hyperbolic) arctanh; the op is
// the ordinary arctangent.
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
351
 public:
Y
Yu Yang 已提交
352
  void Make() override {
D
dzhwinter 已提交
353 354 355
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
A
Adam 已提交
356 357 358 359 360 361 362
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
K
Kexin Zhao 已提交
363
    AddComment(R"DOC(
D
dzhwinter 已提交
364
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
365

D
dzhwinter 已提交
366
$out = \max(x, \alpha * x)$
K
Kexin Zhao 已提交
367 368

)DOC");
369 370 371
  }
};

D
dzhwinter 已提交
372
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
373
 public:
Y
Yu Yang 已提交
374
  void Make() override {
D
dzhwinter 已提交
375 376 377
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
378
    AddComment(R"DOC(
379 380 381
:strong:`Softshrink Activation Operator`

..  math::
382
    out = \begin{cases}
383 384 385 386
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
387 388

)DOC");
K
kexinzhao 已提交
389 390 391
  }
};

D
dzhwinter 已提交
392
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
393
 public:
Y
Yu Yang 已提交
394
  void Make() override {
D
dzhwinter 已提交
395 396
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
397 398
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
399
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
400
    AddComment(R"DOC(
Y
yuyang18 已提交
401
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
402

Y
yuyang18 已提交
403 404 405 406 407 408
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
409 410

)DOC");
411 412 413
  }
};

414 415
// Proto maker for brelu (bounded relu). Attributes "t_min" (default 0) and
// "t_max" (default 24) are the clipping bounds. The previous formula applied
// min/max in the wrong order, which evaluates to t_max for every x whenever
// t_min <= t_max.
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

// Proto maker for soft_relu. Attribute "threshold" (default 40) clips the
// input to [-threshold, threshold] before the softplus.
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

// Proto maker for elu. Attribute "alpha" defaults to 1.0.
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

// Proto maker for relu6. The upper clip is the "threshold" attribute
// (default 6). The formula previously hard-coded 6 and ignored the attribute.
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), threshold)$

)DOC");
  }
};

// Proto maker for pow. The exponent is taken from the optional (dispensable)
// "FactorTensor" input when provided, otherwise from the "factor" attribute
// (default 1.0). The FactorTensor description previously concatenated its
// sentences without separators ("use thisThe shape ... [1].it has").
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this. "
             "The shape of FactorTensor MUST BE [1]. "
             "It has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

// Proto maker for stanh (scaled tanh). Attributes: "scale_a" (default 0.67)
// and "scale_b" (default 1.7159), written as a and b in the formula below.
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A LoDTensor or Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

// Proto maker for thresholded_relu. Attribute "threshold" defaults to 1.0.
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

// Proto maker for hard_sigmoid. Attributes: "slope" (default 0.2) and
// "offset" (default 0.5). The description previously called the offset
// "shift", which matches no attribute.
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

// Proto maker for swish. Attribute "beta" defaults to 1.0.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

// Proto maker for hard_swish. Attributes: "threshold" (default 6), "scale"
// (default 6) and "offset" (default 3).
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

// Define the <Name>OpMaker classes for the attribute-less activations, each
// described by its matching <Name>Doc string.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

template <ActBwdOpFwdDeps kDepValue>
631 632 633 634 635
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
636
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
637
      if (ctx->HasOutput("DX")) {
638 639 640
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
641
      if (ctx->HasOutput("DDOut")) {
642 643 644
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
645
    }
646
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
647
      if (ctx->HasOutput("DOut")) {
648 649 650
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Second-order gradient operator whose InferShape only handles DDOut. When
// present, DDOut shares dims and LoD with X (kDepX bit) and/or Out (kDepOut
// bit). The kernel type is derived from DDX.
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
// Builds "relu_grad_grad" from relu_grad.
//   Inputs:  Out (y) and DDX (gradient of relu_grad's X@GRAD output).
//   Output:  DDOut (gradient w.r.t. relu_grad's Out@GRAD input).
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Take ownership immediately so the desc is released if a later call
    // throws. The raw `new` used to be wrapped only at the return statement.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
// Builds "leaky_relu_grad_grad" from leaky_relu_grad.
//   Inputs:  Out and DDX.
//   Output:  DDOut.
//   The attributes (including alpha) are copied from the grad op.
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from the start so it cannot leak if a later call throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("leaky_relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

// sqrt Grad:     dx  = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y        (d dx / d dy = 0.5 / y)
//                dy  = -1 * dx * ddx / y    (d dx / d y = -0.5 * dy / y^2
//                                             = -dx / y)
// In GradGrad, "dy" (output DOut) is the gradient w.r.t. the forward output y.
// Builds "sqrt_grad_grad" from sqrt_grad.
//   Inputs:  Out, DX (sqrt_grad's X@GRAD output) and DDX.
//   Outputs: DOut and DDOut.
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from the start so it cannot leak if a later call throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// Builds "square_grad_grad" from square_grad.
//   Inputs:  X, DOut (square_grad's Out@GRAD input) and DDX.
//   Outputs: DX and DDOut.
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from the start so it cannot leak if a later call throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

// In-place (input, output) pairs:
//   - First-order grad op:  Out@GRAD -> X@GRAD.
//   - Double-grad ops:      DDX -> DDOut.
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
                           {"DDX", "DDOut"});

// Builds the "pow_grad" op description. It forwards X and the optional
// FactorTensor (the runtime exponent), consumes Out@GRAD, produces X@GRAD, and
// inherits the forward op's attributes.
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto grad_desc =
        std::unique_ptr<framework::OpDesc>(new framework::OpDesc());
    grad_desc->SetType("pow_grad");

    // Forward-side inputs needed to evaluate d(x^factor)/dx.
    grad_desc->SetInput("X", Input("X"));
    grad_desc->SetInput("FactorTensor", Input("FactorTensor"));

    // Upstream gradient in, gradient w.r.t. X out.
    grad_desc->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_desc->SetOutput(framework::GradVarName("X"), InputGrad("X"));

    grad_desc->SetAttrMap(Attrs());
    return grad_desc;
  }
};

// Forward "pow" operator. Out mirrors X's dims and LoD, and the kernel type is
// derived from X. The optional FactorTensor input reports the expected kernel
// type as its own, so (presumably) no data transform is applied to it. Every
// other input keeps its own place and layout, but with the expected dtype.
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Elementwise op: output metadata is copied straight from the input.
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    const bool is_factor_tensor = (var_name == "FactorTensor");
    return is_factor_tensor
               ? expected_kernel_type
               : framework::OpKernelType(expected_kernel_type.data_type_,
                                         tensor.place(), tensor.layout());
  }
};

// Gradient operator for `pow`. The gradient of X takes the dims and LoD of
// the gradient of Out; the kernel is chosen from the gradient of Out.
class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    const auto dout_var = framework::GradVarName("Out");
    const auto dx_var = framework::GradVarName("X");
    ctx->ShareDim(dout_var, dx_var);
    ctx->ShareLoD(dout_var, dx_var);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  // Same policy as the forward op: FactorTensor reports the expected kernel
  // type; all other variables keep their own place/layout with the expected
  // data type.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    return var_name == "FactorTensor"
               ? expected_kernel_type
               : framework::OpKernelType(expected_kernel_type.data_type_,
                                         tensor.place(), tensor.layout());
  }
};
Q
qijun 已提交
859 860 861 862
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
863
namespace plat = paddle::platform;
864

865 866 867 868 869 870 871 872
// Registers the forward activation op KERNEL_TYPE (maker ops::OP_NAME##OpMaker)
// and its gradient op KERNEL_TYPE##_grad.
//  - The grad desc maker is specialised on grad_functor<float>::FwdDeps().
//  - Forward in-place (SingleOpInplaceInToOut) is enabled only when
//    CanInplaceAct<grad_functor<float>>() is true; otherwise `void` is passed.
//  - The grad op uses ActivationGradOpInplaceInference.
// `functor` is accepted for signature parity with the kernel macro but unused.
// Keep comments off the continuation lines: a `//` comment there would swallow
// the spliced next line.
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);

// Registers float and double CPU kernels for act_type
// (ActivationKernel<functor>) and for act_type##_grad
// (ActivationGradKernel<grad_functor>). `op_name` is unused here.
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

// Apply both registration macros to every activation enumerated by
// FOR_EACH_ACTIVATION_OP (defined outside this chunk, presumably in
// activation_op.h).
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);

/* ==========================    relu register  ============================= */
// relu is registered explicitly so that relu_grad can also carry
// ReluDoubleGradMaker and a relu_grad_grad op with its own kernels exists.
// The forward op passes SingleOpInplaceInToOut unconditionally.
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

// relu_grad_grad CPU kernels: float, double and plat::float16.
REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== leaky relu register  ============================ */
// leaky_relu mirrors the relu registration: explicit forward/grad ops, a
// LeakyReluDoubleGradMaker on the grad op, and a leaky_relu_grad_grad op.
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
// leaky_relu_grad_grad CPU kernels: float, double and plat::float16.
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
// sqrt: forward/grad ops plus a sqrt_grad_grad op. Unlike relu, the double
// grad op is ActivationOpDoubleGrad keyed on SqrtGradGradFunctor's FwdDeps,
// and its kernels are SqrtDoubleGradKernel.
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
// sqrt_grad_grad CPU kernels: float, double and plat::float16.
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
// square: forward/grad ops plus a square_grad_grad op, registered the same
// way as sqrt (ActivationOpDoubleGrad + SquareDoubleGradKernel).
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

// square_grad_grad CPU kernels: float, double and plat::float16.
REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   pow register  ============================ */

// pow uses dedicated op classes: PowOp / PowOpGrad special-case the
// FactorTensor input in GetKernelTypeForVar, and PowGradOpDescMaker forwards
// FactorTensor to pow_grad. Forward in-place is enabled only when
// CanInplaceAct<PowGradFunctor<float>>() holds; otherwise `void` is passed.
REGISTER_OPERATOR(
    pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
    ops::PowGradOpDescMaker,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInference);

// CPU kernels for pow and pow_grad: float and double only.
REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
    pow_grad,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
    ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */