activation_op.cc 29.2 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Q
qijun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Q
qijun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
T
tink2123 已提交
16
#include <memory>
D
dzhwinter 已提交
17
#include <string>
18
#include <type_traits>
T
tink2123 已提交
19
#include <unordered_map>
20
#include <vector>
21
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
D
dzhwinter 已提交
22
#include "paddle/fluid/platform/port.h"
23 24 25
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
Q
qijun 已提交
26

A
Adam 已提交
27 28
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
29 30 31
namespace paddle {
namespace operators {

32 33
using paddle::framework::Tensor;

34 35 36 37 38
// Compile-time predicate over a backward functor: true when its declared
// forward dependency is exactly kDepOut or kNoDeps, false otherwise.
// REGISTER_ACTIVATION_OP uses the result to decide whether the forward op is
// registered with SingleOpInplaceInToOut.
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return !(GradFunctor::FwdDeps() != kDepOut &&
           GradFunctor::FwdDeps() != kNoDeps);
}

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
// Defines the class `<OP_NAME>OpMaker`, a proto/checker maker with:
//   - input  "X"   and output "Out",
//   - bool attributes "use_mkldnn", "use_cudnn" and "is_test", each
//     defaulting to false,
//   - OP_COMMENT as the operator comment.
// The expansion ends at the closing brace of the class; the trailing
// semicolon is supplied at the invocation site.
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }
D
dzhwinter 已提交
61

62 63 64 65 66 67 68 69 70 71 72 73 74
// Builds the OpDesc of the "<forward type>_grad" op for an activation.
//
// The grad op always reads Out@GRAD, writes X@GRAD and copies the forward
// op's attribute map. kDepValue is treated as a bit mask naming the forward
// tensors the backward pass needs:
//   - "X" is wired in when the kDepX bit is set, or when MKL-DNN is requested
//     through FLAGS_use_mkldnn or a true "use_mkldnn" attribute;
//   - "Out" is wired in when the kDepOut bit is set.
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType(ForwardOpType() + "_grad");
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetAttrMap(Attrs());

    constexpr int dep_bits = static_cast<int>(kDepValue);

    // Evaluated left to right with short-circuiting: the attribute is only
    // read after HasAttr() has confirmed it exists.
    const bool wants_x =
        (dep_bits & static_cast<int>(ActBwdOpFwdDeps::kDepX)) != 0 ||
        FLAGS_use_mkldnn ||
        (grad_op->HasAttr("use_mkldnn") &&
         boost::get<bool>(grad_op->GetAttr("use_mkldnn")));
    if (wants_x) {
      grad_op->SetInput("X", Input("X"));
    }

    const bool wants_out =
        (dep_bits & static_cast<int>(ActBwdOpFwdDeps::kDepOut)) != 0;
    if (wants_out) {
      grad_op->SetInput("Out", Output("Out"));
    }

    return grad_op;
  }
};
D
dzhwinter 已提交
90

91 92 93 94
// Chooses the kernel type for an activation op (or one of its gradients).
//
//   ctx   execution context; provides the place and the variable `name`.
//   oper  operator whose attribute map is searched for "use_mkldnn".
//   name  input variable whose data type becomes the kernel data type.
//
// The result defaults to the plain library with any layout. In builds with
// PADDLE_WITH_MKLDNN, both library and layout switch to MKLDNN when the
// operator has a "use_mkldnn" attribute and platform::CanMKLDNNBeUsed(ctx)
// returns true.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto library = framework::LibraryType::kPlain;
  auto layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  const auto& attrs = oper.Attrs();
  const bool has_mkldnn_attr = attrs.find("use_mkldnn") != attrs.end();
  if (library == framework::LibraryType::kPlain && has_mkldnn_attr &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

Q
qijun 已提交
119 120 121 122
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

123
  void InferShape(framework::InferShapeContext* ctx) const override {
124
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
125
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
126
  }
127

128
 protected:
129 130 131 132
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
133 134
};

C
chengduo 已提交
135 136 137 138 139 140
// Var-type inference for activation ops: reports the single pairing
// "X" -> "Out" to PassInDtypeAndVarTypeToOutput.
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return {{"X", /*->*/ "Out"}};
  }
};

Q
qijun 已提交
144 145 146 147
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

148
  void InferShape(framework::InferShapeContext* ctx) const override {
149 150 151
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
152
  }
153

154
 protected:
155 156
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
157
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
158
  }
Q
qijun 已提交
159 160
};

D
dzhwinter 已提交
161
UNUSED constexpr char SigmoidDoc[] = R"DOC(
162
Sigmoid Activation Operator
K
Kexin Zhao 已提交
163

164
$$out = \\frac{1}{1 + e^{-x}}$$
K
Kexin Zhao 已提交
165

D
dzhwinter 已提交
166
)DOC";
Q
qijun 已提交
167

D
dzhwinter 已提交
168
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
169
Logsigmoid Activation Operator
K
Kexin Zhao 已提交
170

171
$$out = \\log \\frac{1}{1 + e^{-x}}$$
K
Kexin Zhao 已提交
172

D
dzhwinter 已提交
173
)DOC";
174

D
dzhwinter 已提交
175
UNUSED constexpr char ExpDoc[] = R"DOC(
K
kexinzhao 已提交
176
Exp Activation Operator.
K
Kexin Zhao 已提交
177

F
fengjiayi 已提交
178
$out = e^x$
K
Kexin Zhao 已提交
179

D
dzhwinter 已提交
180
)DOC";
Q
qijun 已提交
181

D
dzhwinter 已提交
182
UNUSED constexpr char ReluDoc[] = R"DOC(
K
kexinzhao 已提交
183
Relu Activation Operator.
K
Kexin Zhao 已提交
184

F
fengjiayi 已提交
185
$out = \max(x, 0)$
K
Kexin Zhao 已提交
186

D
dzhwinter 已提交
187
)DOC";
K
Kexin Zhao 已提交
188

C
Clementine 已提交
189 190 191 192 193 194 195
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

D
dzhwinter 已提交
196
UNUSED constexpr char TanhDoc[] = R"DOC(
K
kexinzhao 已提交
197
Tanh Activation Operator.
K
Kexin Zhao 已提交
198

Q
update  
qiaolongfei 已提交
199
$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
K
Kexin Zhao 已提交
200

D
dzhwinter 已提交
201
)DOC";
202

D
dzhwinter 已提交
203
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
K
kexinzhao 已提交
204
TanhShrink Activation Operator.
K
Kexin Zhao 已提交
205

Y
Yan Chunwei 已提交
206
$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
K
Kexin Zhao 已提交
207

D
dzhwinter 已提交
208
)DOC";
K
Kexin Zhao 已提交
209

D
dzhwinter 已提交
210
UNUSED constexpr char SqrtDoc[] = R"DOC(
K
kexinzhao 已提交
211
Sqrt Activation Operator.
K
Kexin Zhao 已提交
212

213 214 215
Please make sure legal input, when input a negative value closed to zero,
you should add a small epsilon(1e-12) to avoid negative number caused by numerical errors.

F
fengjiayi 已提交
216
$out = \sqrt{x}$
K
Kexin Zhao 已提交
217

D
dzhwinter 已提交
218
)DOC";
219

Z
zhoukunsheng 已提交
220 221 222 223 224 225 226 227 228
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

D
dzhwinter 已提交
229
UNUSED constexpr char AbsDoc[] = R"DOC(
K
kexinzhao 已提交
230
Abs Activation Operator.
K
Kexin Zhao 已提交
231

F
fengjiayi 已提交
232
$out = |x|$
K
Kexin Zhao 已提交
233

D
dzhwinter 已提交
234
)DOC";
235

D
dzhwinter 已提交
236
UNUSED constexpr char CeilDoc[] = R"DOC(
D
dzhwinter 已提交
237 238
Ceil Activation Operator.

239
$out = \left \lceil x \right \rceil$
D
dzhwinter 已提交
240

D
dzhwinter 已提交
241
)DOC";
D
dzhwinter 已提交
242

D
dzhwinter 已提交
243
UNUSED constexpr char FloorDoc[] = R"DOC(
D
dzhwinter 已提交
244 245
Floor Activation Operator.

246
$out = \left \lfloor x \right \rfloor$
D
dzhwinter 已提交
247

D
dzhwinter 已提交
248
)DOC";
D
dzhwinter 已提交
249

D
dzhwinter 已提交
250
UNUSED constexpr char CosDoc[] = R"DOC(
C
add sin  
chengduoZH 已提交
251
Cosine Activation Operator.
C
add cos  
chengduoZH 已提交
252 253 254

$out = cos(x)$

D
dzhwinter 已提交
255
)DOC";
C
add cos  
chengduoZH 已提交
256

D
dzhwinter 已提交
257
UNUSED constexpr char SinDoc[] = R"DOC(
C
add sin  
chengduoZH 已提交
258 259 260 261
Sine Activation Operator.

$out = sin(x)$

D
dzhwinter 已提交
262
)DOC";
C
add sin  
chengduoZH 已提交
263

D
dzhwinter 已提交
264
UNUSED constexpr char RoundDoc[] = R"DOC(
D
dzhwinter 已提交
265 266
Round Activation Operator.

F
fengjiayi 已提交
267
$out = [x]$
D
dzhwinter 已提交
268

D
dzhwinter 已提交
269
)DOC";
D
dzhwinter 已提交
270

D
dzhwinter 已提交
271
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
K
kexinzhao 已提交
272
Reciprocal Activation Operator.
K
Kexin Zhao 已提交
273

274
$$out = \\frac{1}{x}$$
K
Kexin Zhao 已提交
275

D
dzhwinter 已提交
276
)DOC";
277

D
dzhwinter 已提交
278
UNUSED constexpr char LogDoc[] = R"DOC(
K
kexinzhao 已提交
279
Log Activation Operator.
K
Kexin Zhao 已提交
280

F
fengjiayi 已提交
281
$out = \ln(x)$
K
Kexin Zhao 已提交
282 283 284

Natural logarithm of x.

D
dzhwinter 已提交
285 286
)DOC";

D
dzhwinter 已提交
287
UNUSED constexpr char SquareDoc[] = R"DOC(
D
dzhwinter 已提交
288 289 290
Square Activation Operator.

$out = x^2$
291

D
dzhwinter 已提交
292 293
)DOC";

D
dzhwinter 已提交
294
UNUSED constexpr char SoftplusDoc[] = R"DOC(
D
dzhwinter 已提交
295 296 297 298 299 300
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

D
dzhwinter 已提交
301
UNUSED constexpr char SoftsignDoc[] = R"DOC(
D
dzhwinter 已提交
302 303
Softsign Activation Operator.

304
$$out = \\frac{x}{1 + \|x\|}$$
D
dzhwinter 已提交
305 306 307

)DOC";

T
tink2123 已提交
308 309 310 311 312 313
// Proto maker for the acos operator: declares input "X", output "Out" and the
// operator comment. No attributes are registered.
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};
321

T
tink2123 已提交
322 323 324 325 326 327
// Proto maker for the asin operator: declares input "X", output "Out" and the
// operator comment. No attributes are registered.
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};
335

T
tink2123 已提交
336 337 338 339 340 341
// Proto maker for the atan operator: declares input "X", output "Out" and the
// operator comment. No attributes are registered.
//
// The comment describes the inverse tangent; it previously named the inverse
// hyperbolic tangent (arctanh / \tanh^{-1}), which is a different function
// from the one this op is registered as.
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};
349

D
dzhwinter 已提交
350
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
351
 public:
Y
Yu Yang 已提交
352
  void Make() override {
D
dzhwinter 已提交
353 354 355
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
A
Adam 已提交
356 357 358 359 360 361 362
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
K
Kexin Zhao 已提交
363
    AddComment(R"DOC(
D
dzhwinter 已提交
364
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
365

D
dzhwinter 已提交
366
$out = \max(x, \alpha * x)$
K
Kexin Zhao 已提交
367 368

)DOC");
369 370 371
  }
};

D
dzhwinter 已提交
372
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
373
 public:
Y
Yu Yang 已提交
374
  void Make() override {
D
dzhwinter 已提交
375 376 377
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
378
    AddComment(R"DOC(
379 380 381
:strong:`Softshrink Activation Operator`

..  math::
382
    out = \begin{cases}
383 384 385 386
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
387 388

)DOC");
K
kexinzhao 已提交
389 390 391
  }
};

D
dzhwinter 已提交
392
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
393
 public:
Y
Yu Yang 已提交
394
  void Make() override {
D
dzhwinter 已提交
395 396
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
397 398
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
399
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
400
    AddComment(R"DOC(
Y
yuyang18 已提交
401
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
402

Y
yuyang18 已提交
403 404 405 406 407 408
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
409 410

)DOC");
411 412 413
  }
};

414 415
// Proto maker for the BRelu (bounded relu) operator.
// Attributes: "t_min" (float, default 0) and "t_max" (float, default 24).
//
// The documented formula clamps x into [t_min, t_max]. The earlier text had
// min and max swapped, which with the default attributes would describe a
// function that returns at least 24 for every input.
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

// Proto maker for the SoftRelu operator.
// Attribute: "threshold" (float, default 40), the clipping bound that appears
// in the formula of the operator comment.
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");

    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);

    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

448 449
// Proto maker for the ELU operator.
// Attribute: "alpha" (float, default 1).
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");

    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);

    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

466 467
// Proto maker for the Relu6 operator.
// Attribute: "threshold" (float, default 6).
//
// The documented formula uses `threshold` as the upper bound, matching the
// attribute declared here, instead of the hard-coded literal 6 (which is
// only the default value).
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), threshold)$

)DOC");
  }
};

482 483
// Proto maker for the Pow operator.
// Attribute: "factor" (float, default 1), the exponent in the formula of the
// operator comment.
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");

    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);

    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

// Proto maker for the STanh (scaled tanh) operator.
// Attributes: "scale_a" (float, default 2/3) and "scale_b" (float, default
// 1.7159).
//
// Per the formula in the operator comment, `a` multiplies the input inside
// the tanh and `b` multiplies the result, so "scale_b" is described as an
// output scale (it was previously described as an input scale).
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the output")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

515 516
// Proto maker for the ThresholdedRelu operator.
// Attribute: "threshold" (float, default 1).
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");

    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);

    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

536 537
// Proto maker for the HardSigmoid operator.
// Attributes: "slope" (float, default 0.2) and "offset" (float, default 0.5).
//
// The operator comment names the additive term `offset` throughout, matching
// the attribute declared here; it previously called it `shift` in the formula
// and in one sentence, a name no attribute of this op carries.
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

A
Abhinav Arora 已提交
561 562
// Proto maker for the Swish operator.
// Attribute: "beta" (float, default 1).
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");

    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);

    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

H
huangjun12 已提交
576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601
// Proto maker for the HardSwish operator.
// Attributes: "threshold" (float, default 6), "scale" (float, default 6) and
// "offset" (float, default 3).
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");

    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);

    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

D
dzhwinter 已提交
602 603 604 605
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
C
Clementine 已提交
606
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
D
dzhwinter 已提交
607 608 609
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
Z
zhoukunsheng 已提交
610
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
D
dzhwinter 已提交
611 612 613 614 615 616 617 618 619 620 621 622
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

623
template <ActBwdOpFwdDeps kDepValue>
624 625 626 627 628
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
629
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
630
      if (ctx->HasOutput("DX")) {
631 632 633
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
634
      if (ctx->HasOutput("DDOut")) {
635 636 637
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
638
    }
639
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
640
      if (ctx->HasOutput("DOut")) {
641 642 643
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Second-order gradient operator whose shape inference only handles "DDOut".
//
// kDepValue is treated as a bit mask. When the kDepX bit is set, "DDOut"
// takes its dims/LoD from "X"; when the kDepOut bit is set, it takes them
// from "Out". Nothing happens if the op has no "DDOut" output. The kernel
// type is derived from input "DDX".
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Copies dims and LoD from `src` to "DDOut" when that output exists.
    auto share_to_ddout = [ctx](const char* src) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim(src, "DDOut");
        ctx->ShareLoD(src, "DDOut");
      }
    };

    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      share_to_ddout("X");
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      share_to_ddout("Out");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

685 686 687 688 689 690 691 692 693 694 695 696 697 698
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
// Builds the "relu_grad_grad" OpDesc from a relu_grad op.
//   inputs : "Out" (forward output), "DDX" (gradient of X@GRAD)
//   outputs: "DDOut" (gradient w.r.t. Out@GRAD)
// The attribute map of the op being differentiated is copied over.
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the moment it is created, so the descriptor
    // is released if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

708 709 710 711 712 713 714 715 716 717 718
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
// Builds the "leaky_relu_grad_grad" OpDesc from a leaky_relu_grad op.
//   inputs : "Out" (forward output), "DDX" (gradient of X@GRAD)
//   outputs: "DDOut" (gradient w.r.t. Out@GRAD)
// The attribute map of the op being differentiated is copied over.
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the moment it is created, so the descriptor
    // is released if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("leaky_relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

L
lvmengsi 已提交
730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// Builds the "sqrt_grad_grad" OpDesc from a sqrt_grad op.
//   inputs : "Out" (forward output), "DX" (X@GRAD produced by sqrt_grad),
//            "DDX" (gradient of X@GRAD)
//   outputs: "DOut" (gradient w.r.t. Out), "DDOut" (gradient w.r.t. Out@GRAD)
// The attribute map of the op being differentiated is copied over.
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the moment it is created, so the descriptor
    // is released if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// Builds the "square_grad_grad" OpDesc from a square_grad op.
//   inputs : "X" (forward input), "DOut" (Out@GRAD), "DDX" (gradient of
//            X@GRAD)
//   outputs: "DX" (gradient w.r.t. X), "DDOut" (gradient w.r.t. Out@GRAD)
// The attribute map of the op being differentiated is copied over.
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned by a unique_ptr from the moment it is created, so the descriptor
    // is released if any of the calls below throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

777 778 779
// Declares ActivationGradOpInplaceInference, which is passed when the *_grad
// activation ops are registered below. It pairs Out@GRAD with X@GRAD.
// NOTE(review): presumably this lets X@GRAD reuse the buffer of Out@GRAD;
// confirm against the definition of DECLARE_INPLACE_OP_INFERER.
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
780

Q
qijun 已提交
781 782 783 784
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
785
namespace plat = paddle::platform;
786

787 788 789 790 791 792 793 794
// Registers one activation and its gradient operator.
//   KERNEL_TYPE   operator type name; the gradient is KERNEL_TYPE##_grad.
//   OP_NAME       prefix of the proto maker class (ops::OP_NAME##OpMaker).
//   functor       forward functor (not used by this macro; it is part of the
//                 signature so the macro can be fed to FOR_EACH_ACTIVATION_OP).
//   grad_functor  backward functor template; its FwdDeps() picks the grad
//                 desc maker and decides, through CanInplaceAct, whether
//                 SingleOpInplaceInToOut or `void` is passed to the forward
//                 registration.
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);
797 798 799

// Registers the CPU kernels of one activation operator and of its gradient.
//   act_type     - operator name; the gradient kernels go to act_type##_grad
//   op_name      - not referenced by this macro (presumably kept so the
//                  parameter list matches REGISTER_ACTIVATION_OP)
//   functor      - forward functor template, instantiated for float and double
//   grad_functor - gradient functor template, instantiated for float and
//                  double
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);
811

812 813
// Apply both registration macros through FOR_EACH_ACTIVATION_OP: first the
// operator/grad-operator registration, then the float/double CPU kernels.
// NOTE(review): FOR_EACH_ACTIVATION_OP is defined outside this chunk;
// presumably it invokes the given macro once per generic activation.
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
814

815
/* ==========================    relu register  ============================= */
816 817 818 819 820
// relu is registered by hand (not through REGISTER_ACTIVATION_OP) because its
// gradient operator also needs a double-grad maker:
//   relu           - forward op, SingleOpInplaceInToOut passed unconditionally
//   relu_grad      - ActivationGradOpInplaceInference + ReluDoubleGradMaker
//   relu_grad_grad - second-order gradient op
// NOTE(review): relu_grad_grad takes FwdDeps() from ReluGradFunctor, whereas
// sqrt/square use their *GradGradFunctor; confirm this is intended.
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
826 827 828 829 830 831 832 833 834 835 836

// CPU kernels for relu and relu_grad (float and double, via the macro).
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

// CPU kernels for relu_grad_grad: float, double and float16.
REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
837
/* ========================================================================== */
838

839
/* ======================== leaky relu register  ============================ */
840 841 842 843 844 845
// leaky_relu is registered by hand so that its gradient operator can carry a
// double-grad maker:
//   leaky_relu           - forward op, SingleOpInplaceInToOut passed
//                          unconditionally
//   leaky_relu_grad      - ActivationGradOpInplaceInference +
//                          LeakyReluDoubleGradMaker
//   leaky_relu_grad_grad - second-order gradient op
// NOTE(review): leaky_relu_grad_grad takes FwdDeps() from
// LeakyReluGradFunctor, whereas sqrt/square use their *GradGradFunctor;
// confirm this is intended.
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
851

852 853 854 855 856 857 858 859 860 861
// CPU kernels for leaky_relu and leaky_relu_grad (float and double, via the
// macro).
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
// CPU kernels for leaky_relu_grad_grad: float, double and float16.
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
862 863
/* ========================================================================== */

L
lvmengsi 已提交
864 865 866 867 868 869
/* ===========================   sqrt register  ============================= */
// sqrt is registered by hand so that its gradient operator can carry a
// double-grad maker:
//   sqrt           - forward op, SingleOpInplaceInToOut passed unconditionally
//   sqrt_grad      - ActivationGradOpInplaceInference + SqrtDoubleGradMaker
//   sqrt_grad_grad - second-order gradient op, parameterized by
//                    SqrtGradGradFunctor<float>::FwdDeps()
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
// CPU kernels for sqrt and sqrt_grad (float and double, via the macro).
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
// CPU kernels for sqrt_grad_grad: float, double and float16, using the
// dedicated SqrtDoubleGradKernel.
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

885 886 887 888 889 890 891
/* ==========================   square register  ============================ */
// square is registered by hand so that its gradient operator can carry a
// double-grad maker:
//   square           - forward op, SingleOpInplaceInToOut passed
//                      unconditionally
//   square_grad      - ActivationGradOpInplaceInference +
//                      SquareDoubleGradMaker
//   square_grad_grad - second-order gradient op, parameterized by
//                      SquareGradGradFunctor<float>::FwdDeps()
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);

// CPU kernels for square and square_grad (float and double, via the macro).
REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

// CPU kernels for square_grad_grad: float, double and float16, using the
// dedicated SquareDoubleGradKernel.
REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */