activation_op.cc 28.4 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Q
qijun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Q
qijun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/activation_op.h"
T
tink2123 已提交
16
#include <memory>
D
dzhwinter 已提交
17
#include <string>
18
#include <type_traits>
T
tink2123 已提交
19
#include <unordered_map>
20
#include <vector>
21
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
D
dzhwinter 已提交
22
#include "paddle/fluid/platform/port.h"
23 24 25
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
Q
qijun 已提交
26

A
Adam 已提交
27 28
// Declares (does not define) the global gflags switch FLAGS_use_mkldnn;
// ActivationGradOpDescMaker::Apply reads it when deciding whether the grad op
// must also receive X.
DECLARE_bool(use_mkldnn);

Q
qijun 已提交
29 30 31
namespace paddle {
namespace operators {

32 33
// Lets code in paddle::operators refer to framework::Tensor unqualified.
using paddle::framework::Tensor;

34 35 36 37 38
// True when the backward functor needs nothing from the forward pass, or only
// Out. REGISTER_ACTIVATION_OP uses this to decide whether the forward op is
// registered with SingleOpInplaceInToOut (presumably letting Out reuse X's
// buffer, since X is then not needed by the backward pass).
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kNoDeps ||
         GradFunctor::FwdDeps() == kDepOut;
}

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
// Defines class `<ACT_NAME>OpMaker`, the proto maker shared by the simple
// activations: input X, output Out, boolean attributes use_mkldnn, use_cudnn
// and is_test (all defaulting to false), and ACT_DOC as the op comment.
// The expansion ends at the class's closing brace, so each invocation must
// supply the trailing semicolon.
#define REGISTER_ACTIVATION_OP_MAKER(ACT_NAME, ACT_DOC)                      \
  class ACT_NAME##OpMaker                                                    \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #ACT_NAME " operator");                      \
      AddOutput("Out", "Output of " #ACT_NAME " operator");                  \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(ACT_DOC);                                                   \
    }                                                                        \
  }
D
dzhwinter 已提交
61

62 63 64 65 66 67 68 69 70 71 72 73 74
// Builds the "<forward type>_grad" op for an activation. The grad op always
// takes Out@GRAD, produces X@GRAD and inherits the forward op's attributes.
// kDepValue is a bit set over ActBwdOpFwdDeps naming the forward tensors the
// backward pass reads:
//   * X is wired in when the kDepX bit is set, or when MKLDNN is requested
//     (globally via FLAGS_use_mkldnn, or by a true "use_mkldnn" attribute);
//   * Out is wired in when the kDepOut bit is set.
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    const int deps = static_cast<int>(kDepValue);

    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType(ForwardOpType() + "_grad");
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetAttrMap(Attrs());

    const bool bwd_reads_x =
        (deps & static_cast<int>(ActBwdOpFwdDeps::kDepX)) != 0;
    // Kept as one short-circuiting condition: the attribute is only queried
    // (and boost::get only evaluated) when the earlier checks are false.
    // X is presumably needed by the MKLDNN grad kernels.
    if (bwd_reads_x || FLAGS_use_mkldnn ||
        (grad_op->HasAttr("use_mkldnn") &&
         boost::get<bool>(grad_op->GetAttr("use_mkldnn")))) {
      grad_op->SetInput("X", Input("X"));
    }

    if (deps & static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      grad_op->SetInput("Out", Output("Out"));
    }

    return grad_op;
  }
};
D
dzhwinter 已提交
90

91 92 93 94
// Chooses the kernel for an activation-family op.
//   ctx  - execution context; supplies the place and the input variable.
//   oper - the op whose attributes are inspected.
//   name - input slot whose variable's data type determines the kernel dtype.
// The default is a plain kernel with kAnyLayout. When built with
// PADDLE_WITH_MKLDNN, if `oper` has a "use_mkldnn" attribute (its value is not
// inspected here) and platform::CanMKLDNNBeUsed(ctx) holds, the MKLDNN library
// and layout are selected instead.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto library = framework::LibraryType::kPlain;
  auto layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  const bool has_mkldnn_attr =
      oper.Attrs().find("use_mkldnn") != oper.Attrs().end();
  if (library == framework::LibraryType::kPlain && has_mkldnn_attr &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  const auto data_type = framework::GetDataTypeOfVar(ctx.InputVar(name));
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

Q
qijun 已提交
119 120 121 122
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

123
  void InferShape(framework::InferShapeContext* ctx) const override {
124
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
125
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
126
  }
127

128
 protected:
129 130 131 132
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
133 134
};

C
chengduo 已提交
135 136 137 138 139 140
// Var-type inference for activations: supplies the single X -> Out pairing to
// PassInDtypeAndVarTypeToOutput, which (per its name) propagates the input's
// dtype and variable type to the paired output.
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    std::unordered_map<std::string, std::string> in_to_out;
    in_to_out.emplace("X", "Out");
    return in_to_out;
  }
};

Q
qijun 已提交
144 145 146 147
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

148
  void InferShape(framework::InferShapeContext* ctx) const override {
149 150 151
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
Q
qijun 已提交
152
  }
153

154
 protected:
155 156
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
157
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
158
  }
Q
qijun 已提交
159 160
};

D
dzhwinter 已提交
161
// Operator documentation strings. Each is passed as the comment argument of
// REGISTER_ACTIVATION_OP_MAKER further down, which forwards it to AddComment.
// The leading and trailing newlines are part of each string.

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";
202

D
dzhwinter 已提交
203
// Operator documentation strings consumed by REGISTER_ACTIVATION_OP_MAKER
// further down (forwarded to AddComment).

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

Please make sure legal input, when input a negative value closed to zero,
you should add a small epsilon(1e-12) to avoid negative number caused by numerical errors.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure input is legal in case of numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";
235

D
dzhwinter 已提交
236
// Operator documentation strings consumed by REGISTER_ACTIVATION_OP_MAKER
// further down (forwarded to AddComment).

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";
D
dzhwinter 已提交
270

D
dzhwinter 已提交
271
// Operator documentation strings consumed by REGISTER_ACTIVATION_OP_MAKER
// further down (forwarded to AddComment).

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

T
tink2123 已提交
308 309 310 311 312 313
// Proto maker for the acos activation: input X, output Out, no attributes.
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};
321

T
tink2123 已提交
322 323 324 325 326 327
// Proto maker for the asin activation: input X, output Out, no attributes.
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};
335

T
tink2123 已提交
336 337 338 339 340 341
// Proto maker for the atan activation: input X, output Out, no attributes.
// The op is the (circular) arctangent, matching its acos/asin siblings, so the
// doc describes tan^{-1}, not the hyperbolic tanh^{-1} it previously claimed.
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};
349

D
dzhwinter 已提交
350
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
351
 public:
Y
Yu Yang 已提交
352
  void Make() override {
D
dzhwinter 已提交
353 354 355
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
A
Adam 已提交
356 357 358 359 360 361 362
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
K
Kexin Zhao 已提交
363
    AddComment(R"DOC(
D
dzhwinter 已提交
364
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
365

D
dzhwinter 已提交
366
$out = \max(x, \alpha * x)$
K
Kexin Zhao 已提交
367 368

)DOC");
369 370 371
  }
};

D
dzhwinter 已提交
372
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
373
 public:
Y
Yu Yang 已提交
374
  void Make() override {
D
dzhwinter 已提交
375 376 377
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
378
    AddComment(R"DOC(
379 380 381
:strong:`Softshrink Activation Operator`

..  math::
382
    out = \begin{cases}
383 384 385 386
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
387 388

)DOC");
K
kexinzhao 已提交
389 390 391
  }
};

D
dzhwinter 已提交
392
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
393
 public:
Y
Yu Yang 已提交
394
  void Make() override {
D
dzhwinter 已提交
395 396
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
397 398
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
399
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
400
    AddComment(R"DOC(
Y
yuyang18 已提交
401
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
402

Y
yuyang18 已提交
403 404 405 406 407 408
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
409 410

)DOC");
411 412 413
  }
};

414 415
// Proto maker for BRelu (bounded relu): input X, output Out, float attributes
// t_min (default 0) and t_max (default 24). The output is x clipped to
// [t_min, t_max]. The doc previously had min/max swapped, which would always
// yield t_max.
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

// Proto maker for SoftRelu: input X, output Out, float attribute threshold
// (default 40.0f) used to clip x before the softplus.
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

448 449
// Proto maker for ELU: input X, output Out, float attribute alpha
// (default 1.0f).
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

466 467
// Proto maker for Relu6: input X, output Out, float attribute threshold
// (default 6.0f). The upper bound is the configurable threshold, so the doc
// names it rather than hard-coding 6.
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), threshold)$

)DOC");
  }
};

482 483
// Proto maker for Pow: input X, output Out, float attribute factor
// (default 1.0f) used as the exponent.
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

// Proto maker for STanh (scaled tanh): input X, output Out, float attributes
// scale_a (default 2/3) and scale_b (default 1.7159).
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

515 516
// Proto maker for ThresholdedRelu: input X, output Out, float attribute
// threshold (default 1.0f).
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

536 537
// Proto maker for HardSigmoid: input X, output Out, float attributes slope
// (default 0.2f) and offset (default 0.5f). The doc refers to the attribute by
// its registered name `offset` (it previously said `shift`, which does not
// exist).
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

A
Abhinav Arora 已提交
561 562
// Proto maker for Swish: input X, output Out, float attribute beta
// (default 1.0f).
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

D
dzhwinter 已提交
576 577 578 579
// Define <Name>OpMaker for each simple activation, pairing the class with the
// <Name>Doc string declared earlier in this file. The macro expansion ends at
// the class's closing brace, so each line supplies the semicolon.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

597
template <ActBwdOpFwdDeps kDepValue>
598 599 600 601 602
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
603
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
604
      if (ctx->HasOutput("DX")) {
605 606 607
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
608
      if (ctx->HasOutput("DDOut")) {
609 610 611
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
612
    }
613
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
614
      if (ctx->HasOutput("DOut")) {
615 616 617
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

// Second-order activation grad op variant that only infers DDOut: it takes the
// dims and LoD of X (kDepX bit) and/or Out (kDepOut bit) when the op has a
// DDOut output. Kernel choice uses the dtype of DDX.
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Copies dims and LoD from `src` to `dst`, only if HasOutput(dst).
    auto share_if_output = [ctx](const std::string& src,
                                 const std::string& dst) {
      if (ctx->HasOutput(dst)) {
        ctx->ShareDim(src, dst);
        ctx->ShareLoD(src, dst);
      }
    };

    const int deps = static_cast<int>(kDepValue);
    if (deps & static_cast<int>(kDepX)) {
      share_if_output("X", "DDOut");
    }
    if (deps & static_cast<int>(kDepOut)) {
      share_if_output("Out", "DDOut");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

659 660 661 662 663 664 665 666 667 668 669 670 671 672
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
// Builds relu_grad_grad from relu_grad. Inputs: Out and DDX (the gradient of
// relu_grad's X@GRAD output). Output: DDOut (the gradient w.r.t. relu_grad's
// Out@GRAD input). Attributes are copied from relu_grad.
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from construction so the desc is released if any call below
    // throws (previously a raw pointer was only wrapped at the return).
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
//
// Builds leaky_relu_grad_grad. Inputs: X and DDX (the gradient of the grad
// op's X@GRAD output). Output: DDOut (the gradient w.r.t. the grad op's
// Out@GRAD input). Attributes are copied over.
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from construction so the desc is released if any call below
    // throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("leaky_relu_grad_grad");
    // input1: X
    op->SetInput("X", Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

L
lvmengsi 已提交
704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
// (d(dx)/dy_out = 0.5 / y and d(dx)/dy = -0.5 * dy_out / y^2 = -dx / y,
// where dy_out is Out@GRAD and y is Out).
//
// Builds sqrt_grad_grad. Inputs: Out, DX (the grad op's X@GRAD output) and
// DDX (its gradient). Outputs: DOut (gradient w.r.t. Out) and DDOut (gradient
// w.r.t. Out@GRAD). Attributes are copied over.
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from construction so the desc is released if any call below
    // throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
//
// Builds square_grad_grad. Inputs: X, DOut (Out@GRAD, an input of the grad op)
// and DDX (the gradient of the grad op's X@GRAD output). Outputs: DX (gradient
// w.r.t. X) and DDOut (gradient w.r.t. Out@GRAD). Attributes are copied over.
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    // Owned from construction so the desc is released if any call below
    // throws.
    std::unique_ptr<::paddle::framework::OpDesc> op(
        new ::paddle::framework::OpDesc());
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};

751 752 753 754 755 756 757 758
// In-place hint for activation grad ops: pairs Out@GRAD with X@GRAD. Neither
// the op description nor the use_cuda flag affects the result.
class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const override {
    std::unordered_map<std::string, std::string> inplace_pairs;
    inplace_pairs[framework::GradVarName("Out")] = framework::GradVarName("X");
    return inplace_pairs;
  }
};

Q
qijun 已提交
759 760 761 762
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
763
namespace plat = paddle::platform;
764

765 766 767 768 769 770 771 772
// Registers forward op ACT_TYPE and its ACT_TYPE_grad op.
//   ACT_TYPE    - op type name; the grad op is ACT_TYPE##_grad.
//   ACT_NAME    - prefix of the proto maker class ops::ACT_NAME##OpMaker.
//   FWD_FUNCTOR - unused here (presumably accepted so the macro matches the
//                 FOR_EACH_ACTIVATION_OP callback signature).
//   BWD_FUNCTOR - grad functor template; its FwdDeps() selects which forward
//                 tensors the grad op receives and, via CanInplaceAct, whether
//                 SingleOpInplaceInToOut is attached (void otherwise).
// Note: no line-end comments inside the body; they would swallow the
// backslash-continued next line.
#define REGISTER_ACTIVATION_OP(ACT_TYPE, ACT_NAME, FWD_FUNCTOR, BWD_FUNCTOR)  \
  REGISTER_OPERATOR(                                                         \
      ACT_TYPE, ops::ActivationOp, ops::ACT_NAME##OpMaker,                   \
      ops::ActivationOpInferVarType,                                         \
      ops::ActivationGradOpDescMaker<ops::BWD_FUNCTOR<float>::FwdDeps()>,    \
      std::conditional<ops::CanInplaceAct<ops::BWD_FUNCTOR<float>>(),       \
                       ::paddle::framework::SingleOpInplaceInToOut,          \
                       void>::type);                                         \
  REGISTER_OPERATOR(ACT_TYPE##_grad, ops::ActivationOpGrad,                  \
                    ops::ActivationGradOpInplaceInference);
775 776 777

// Registers float and double CPU kernels for the forward op act_type
// (ActivationKernel over ops::functor) and for act_type##_grad
// (ActivationGradKernel over ops::grad_functor).
// `op_name` is unused here. It keeps the parameter list identical to
// REGISTER_ACTIVATION_OP so both macros can be passed to
// FOR_EACH_ACTIVATION_OP.
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

// For every activation listed in FOR_EACH_ACTIVATION_OP: register the
// operator pair, then its CPU kernels.
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);

/* ==========================    relu register  ============================= */
// relu is registered explicitly so relu_grad can carry ReluDoubleGradMaker,
// which REGISTER_ACTIVATION_OP does not attach. The resulting relu_grad_grad
// op gets its own CPU kernels for float, double and float16.
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== leaky relu register  ============================ */
// leaky_relu is registered explicitly so leaky_relu_grad can carry
// LeakyReluDoubleGradMaker. The resulting leaky_relu_grad_grad op gets CPU
// kernels for float, double and float16.
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
// sqrt is registered explicitly so sqrt_grad can carry SqrtDoubleGradMaker.
// Unlike relu, sqrt_grad_grad uses ActivationOpDoubleGrad keyed on
// SqrtGradGradFunctor's FwdDeps(), and its CPU kernels are
// SqrtDoubleGradKernel (float, double, float16).
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
// square is registered explicitly so square_grad can carry
// SquareDoubleGradMaker. square_grad_grad uses ActivationOpDoubleGrad keyed on
// SquareGradGradFunctor's FwdDeps(), with SquareDoubleGradKernel CPU kernels
// (float, double, float16).
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */